mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-19 21:07:07 +08:00
merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
commit
48bca9a109
@ -382,7 +382,7 @@ run_genai_perf_tests() {
|
||||
client_command="genai-perf profile \
|
||||
-m $model \
|
||||
--service-kind openai \
|
||||
--backend vllm \
|
||||
--backend "$backend" \
|
||||
--endpoint-type chat \
|
||||
--streaming \
|
||||
--url localhost:$port \
|
||||
|
||||
@ -843,3 +843,10 @@ steps:
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||
|
||||
- label: Qwen MoE EP Test # optional
|
||||
gpu: h200
|
||||
optional: true
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
|
||||
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -61,13 +64,13 @@ def benchmark_decode(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_seq_len
|
||||
@ -75,14 +78,13 @@ def benchmark_decode(
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -142,11 +144,31 @@ def benchmark_decode(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
@ -158,6 +180,7 @@ def benchmark_decode(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -237,6 +260,7 @@ if __name__ == "__main__":
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -72,13 +75,15 @@ def benchmark_prefill(
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(
|
||||
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
|
||||
)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
@ -86,14 +91,13 @@ def benchmark_prefill(
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -152,11 +156,31 @@ def benchmark_prefill(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_prefill():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_prefill():
|
||||
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
@ -172,6 +196,7 @@ def benchmark_prefill(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -250,6 +275,7 @@ if __name__ == "__main__":
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -432,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# Install DeepGEMM from source
|
||||
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"
|
||||
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
|
||||
CUDA_MINOR="${CUDA_MINOR%%.*}"
|
||||
if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then
|
||||
git clone --recursive --shallow-submodules \
|
||||
${DEEPGEMM_GIT_REPO} deepgemm
|
||||
echo "🏗️ Building DeepGEMM"
|
||||
pushd deepgemm
|
||||
git checkout ${DEEPGEMM_GIT_REF}
|
||||
# Build DeepGEMM
|
||||
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python3 setup.py bdist_wheel
|
||||
uv pip install --system dist/*.whl
|
||||
popd
|
||||
rm -rf deepgemm
|
||||
else
|
||||
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
|
||||
fi
|
||||
BASH
|
||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \
|
||||
&& rm /tmp/install_deepgemm.sh
|
||||
|
||||
# Install EP kernels(pplx-kernels and DeepEP), NixL
|
||||
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
|
||||
COPY tools/install_nixl.sh install_nixl.sh
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
|
||||
&& bash install_python_libraries.sh \
|
||||
&& bash install_nixl.sh --force
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
@ -172,6 +172,7 @@ The availablilty of batch-level DP is based on model implementation.
|
||||
Currently, the following models support `mm_encoder_tp_mode="data"`:
|
||||
|
||||
- Llama4 (<gh-pr:18368>)
|
||||
- MiniCPM-V-4 (<gh-pr:23327>)
|
||||
- Qwen2.5-VL (<gh-pr:22742>)
|
||||
- Step3 (<gh-pr:22697>)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
|
||||
|
||||
There are other miscellaneous places hard-coding the use of `spawn`:
|
||||
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
|
||||
|
||||
Related PRs:
|
||||
|
||||
@ -284,6 +284,14 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`
|
||||
|
||||
### DeepSeek-V3.1 Models (`deepseek_v31`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `deepseek-ai/DeepSeek-V3.1` (use with <gh-file:examples/tool_chat_template_deepseekv31.jinja>)
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}`
|
||||
|
||||
### Kimi-K2 Models (`kimi_k2`)
|
||||
|
||||
Supported models:
|
||||
|
||||
@ -401,6 +401,7 @@ th {
|
||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
|
||||
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -166,7 +166,7 @@ Processed means the values after applying all processors, including temperature
|
||||
|
||||
##### Prompt Logprobs with Prefix Caching
|
||||
|
||||
Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414).
|
||||
Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.
|
||||
|
||||
#### Deprecated Features
|
||||
|
||||
|
||||
91
examples/tool_chat_template_deepseekv31.jinja
Normal file
91
examples/tool_chat_template_deepseekv31.jinja
Normal file
@ -0,0 +1,91 @@
|
||||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% if not thinking is defined %}
|
||||
{% set thinking = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{%- if ns.is_first_sp %}
|
||||
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
|
||||
{% set ns.is_first_sp = false %}
|
||||
{%- else %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{% if tools is defined and tools is not none %}
|
||||
{% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %}
|
||||
{% for tool in tools %}
|
||||
{% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %}
|
||||
{% endfor %}
|
||||
{% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
|
||||
{% endif %}
|
||||
|
||||
{{ bos_token }}{{ ns.system_prompt }}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_last_user = true -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|></think>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_first = false %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls'] %}
|
||||
{%- if not ns.is_first %}
|
||||
{%- if message['content'] is none %}
|
||||
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- else %}
|
||||
{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
|
||||
{{'<think>'}}
|
||||
{%- else %}
|
||||
{{'</think>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool %}
|
||||
{{message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set content = content.split('</think>', 1)[1] -%}
|
||||
{%- endif %}
|
||||
{{content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if not thinking %}
|
||||
{{'</think>'}}
|
||||
{%- else %}
|
||||
{{'<think>'}}
|
||||
{%- endif %}
|
||||
{% endif %}
|
||||
@ -13,7 +13,7 @@ protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
aiohttp
|
||||
openai >= 1.99.1 # For Responses API with reasoning content
|
||||
pydantic >= 2.10
|
||||
pydantic >= 2.11.7
|
||||
prometheus_client >= 0.18.0
|
||||
pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
|
||||
@ -6,7 +6,7 @@ torch==2.7.0
|
||||
torchvision==0.22.0
|
||||
torchaudio==2.7.0
|
||||
|
||||
triton==3.2
|
||||
triton==3.3.0
|
||||
cmake>=3.26.1,<4
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
|
||||
@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
conch-triton-kernels==1.2.1
|
||||
conch-triton-kernels==1.2.1
|
||||
@ -742,7 +742,7 @@ pycparser==2.22
|
||||
# via cffi
|
||||
pycryptodomex==3.22.0
|
||||
# via blobfile
|
||||
pydantic==2.11.5
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albumentations
|
||||
|
||||
2
setup.py
2
setup.py
@ -695,6 +695,8 @@ setup(
|
||||
"video": [], # Kept for backwards compatibility
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
"flashinfer": ["flashinfer-python==0.2.12"],
|
||||
# Optional deps for AMD FP4 quantization support
|
||||
"petit-kernel": ["petit-kernel"],
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
package_data=package_data,
|
||||
|
||||
@ -8,11 +8,12 @@ import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@ -7,11 +7,13 @@ import torch
|
||||
import vllm.envs as envs
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
FusionPass, GroupShape, QuantKey)
|
||||
FusionPass)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
||||
from vllm.platforms import current_platform
|
||||
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.key = QuantKey(dtype=FP8_DTYPE,
|
||||
static=static,
|
||||
group_shape=group_shape,
|
||||
symmetric=True)
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
else:
|
||||
|
||||
@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: Optional[TestBackend] = None
|
||||
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
|
||||
# check support
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
||||
quant_key.static,
|
||||
quant_key.group_shape)
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in compile_config.static_forward_context.items()
|
||||
]
|
||||
|
||||
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
backend = None
|
||||
|
||||
|
||||
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
"""Test model for AttentionStaticQuantPattern fusion."""
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
|
||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
||||
vllm_config: VllmConfig):
|
||||
vllm_config: VllmConfig, **kwargs):
|
||||
super().__init__()
|
||||
self.num_qo_heads = num_qo_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
prefix="model.layers.0.self_attn.attn",
|
||||
)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
||||
self.wscale = torch.tensor([1.0], dtype=torch.float32)
|
||||
self.scale = torch.tensor([1.0], dtype=torch.float32)
|
||||
|
||||
self.block_size = 16
|
||||
|
||||
# Initialize attn MetadataBuilder
|
||||
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
|
||||
return self.attn_metadata
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
w: torch.Tensor):
|
||||
|
||||
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
"""Test model for AttentionFp8StaticQuantPattern fusion."""
|
||||
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.quant_key.scale.static,
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randn(hidden_size, hidden_size).to(
|
||||
dtype=FP8_DTYPE, device=self.device).t(),
|
||||
"wscale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
return self.fp8_linear.apply(input=attn_output,
|
||||
weight=w,
|
||||
weight_scale=self.wscale,
|
||||
input_scale=self.scale)
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"])
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
"""Test model for AttentionNvfp4QuantPattern fusion."""
|
||||
|
||||
quant_key = kNvfp4Quant
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randint(256, (hidden_size, hidden_size // 2),
|
||||
dtype=FP4_DTYPE,
|
||||
device=self.device),
|
||||
"wscale_swizzled":
|
||||
torch.randn(hidden_size, hidden_size // 16).to(
|
||||
dtype=FP8_DTYPE, device=self.device),
|
||||
"wscale":
|
||||
torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
quant_output, output_block_scale = scaled_fp4_quant(
|
||||
attn_output, 1 / self.w["scale"])
|
||||
return cutlass_scaled_fp4_mm(a=quant_output,
|
||||
b=self.w["weight"],
|
||||
block_scale_a=output_block_scale,
|
||||
block_scale_b=self.w["wscale_swizzled"],
|
||||
alpha=self.w["scale"] * self.w["wscale"],
|
||||
out_dtype=attn_output.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("batch_size", [7, 256, 533])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, quant_key",
|
||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
|
||||
@pytest.mark.parametrize("model_name, model_class",
|
||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel),
|
||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel)])
|
||||
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
head_size: int, batch_size: int,
|
||||
dtype: torch.dtype, model_name: str,
|
||||
quant_key: QuantKey, backend: _Backend,
|
||||
monkeypatch, dist_init):
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend, monkeypatch, dist_init):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
||||
|
||||
# Create test inputs
|
||||
hidden_size = num_qo_heads * head_size
|
||||
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
q = torch.randn(batch_size,
|
||||
num_qo_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k = torch.randn(batch_size,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
|
||||
|
||||
# Mark first dimension as dynamic for realistic testing
|
||||
torch._dynamo.mark_dynamic(q, 0)
|
||||
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_unfused = TestAttentionStaticQuantPatternModel(
|
||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
||||
vllm_config_unfused)
|
||||
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config_unfused)
|
||||
model_unfused = model_unfused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
batch_size)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v, linear_w)
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_fused = TestAttentionStaticQuantPatternModel(
|
||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
||||
vllm_config)
|
||||
model_fused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config,
|
||||
w=model_unfused.w)
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
backend=test_backend,
|
||||
fullgraph=True)
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
result_fused_1 = model_compiled(q, k, v, linear_w)
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
# After the 1st round of the forward pass, output quant scale should be
|
||||
# loaded into the attn layer's _o_scale_float, the 2nd round should
|
||||
# reuse the loaded _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v, linear_w)
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key = model_class.quant_key
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
||||
quant_key.static,
|
||||
quant_key.group_shape) for key,
|
||||
layer in vllm_config.compilation_config.static_forward_context.items()
|
||||
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
||||
vllm_config.compilation_config.static_forward_context.items()
|
||||
]
|
||||
if any(attn_fusion_supported):
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
||||
"Attention should have output_scale after fusion"
|
||||
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
elif quant_key.dtype == FP4_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
||||
|
||||
# Check that results are closed
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_1,
|
||||
|
||||
@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
|
||||
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
|
||||
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
|
||||
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
|
||||
"AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
|
||||
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
|
||||
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
|
||||
|
||||
108
tests/distributed/test_symm_mem_allreduce.py
Normal file
108
tests/distributed/test_symm_mem_allreduce.py
Normal file
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator)
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
|
||||
test_size_elements = 4 * 1024 * 1024
|
||||
|
||||
|
||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
cuda_communicator = typing.cast(CudaCommunicator,
|
||||
get_tp_group().device_communicator)
|
||||
symm_mem_comm = cuda_communicator.symm_mem_comm
|
||||
if symm_mem_comm is None or symm_mem_comm.disabled:
|
||||
pytest.skip("SymmMemCommunicator is not available or disabled.")
|
||||
|
||||
inp_direct_symm_mem = torch.randint(1,
|
||||
23, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
||||
pytest.skip(
|
||||
"SymmMemCommunicator isn't used for this world and input size."
|
||||
)
|
||||
|
||||
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
||||
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
|
||||
assert out_direct_symm_mem is not None
|
||||
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
||||
torch.testing.assert_close(out_direct_symm_mem,
|
||||
original_inp_direct_symm_mem,
|
||||
atol=2.5,
|
||||
rtol=0.1)
|
||||
|
||||
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
||||
inp_tensor_parallel = torch.randint(-23,
|
||||
1, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
||||
out_tensor_parallel = tensor_model_parallel_all_reduce(
|
||||
inp_tensor_parallel)
|
||||
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
||||
torch.testing.assert_close(out_tensor_parallel,
|
||||
original_inp_tensor_parallel,
|
||||
atol=2.5,
|
||||
rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
||||
pipeline_parallel_size):
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
|
||||
# Enable SymmMemCommunicator
|
||||
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
|
||||
|
||||
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
|
||||
cleanup_dist_env_and_memory()
|
||||
@ -74,18 +74,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
|
||||
}
|
||||
|
||||
with pytest.raises(openai.BadRequestError) as err:
|
||||
err = await client.post(path="embeddings",
|
||||
cast_to=object,
|
||||
body={**kwargs})
|
||||
await client.post(path="embeddings", cast_to=object, body={**kwargs})
|
||||
|
||||
assert str(err) == f"""openai.BadRequestError:
|
||||
Error code: 400 - {{'object': 'error',
|
||||
'message': 'truncate_prompt_tokens value
|
||||
({truncation_size})
|
||||
is greater than max_model_len ({max_model_len}).
|
||||
Please, select a smaller truncation size.',
|
||||
'type': 'BadRequestError',
|
||||
'param': None, 'code': 400}}"""
|
||||
assert err.value.status_code == 400
|
||||
error_details = err.value.response.json()["error"]
|
||||
assert error_details["type"] == "BadRequestError"
|
||||
expected_message = ("truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
assert error_details["message"] == expected_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -6,7 +6,11 @@ import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||
@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
|
||||
QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
BATCH_SIZE = [4, 12]
|
||||
MAX_SEQ_LENS = [(1024, 4096)]
|
||||
@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
|
||||
# TRTLLM Decode
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 3e-1, 1e0
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
rtol, atol = 1e-2, 2e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
if q_quant_dtype != kv_quant_dtype:
|
||||
pytest.skip("Skipped mixed QKV dtypes for prefill")
|
||||
|
||||
max_q_len, max_kv_len = max_seq_lens
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
|
||||
# TRTLLM Prefill
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 4e-1, 1e0
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
|
||||
0:-128],
|
||||
requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
vllm_moe.experts.w2_weight = Parameter(F.pad(
|
||||
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
|
||||
0:-128],
|
||||
requires_grad=False)
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
|
||||
@ -3,15 +3,13 @@
|
||||
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
|
||||
]))
|
||||
model.config = MagicMock()
|
||||
model.embedding_modules = {"lm_head": "lm_head"}
|
||||
model.unpadded_vocab_size = 32000
|
||||
return model
|
||||
|
||||
|
||||
@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
|
||||
],
|
||||
}
|
||||
model.embedding_modules = {"lm_head": "lm_head"}
|
||||
model.unpadded_vocab_size = 32000
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@ -221,29 +221,6 @@ def phi2_lora_files():
|
||||
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_2_7b_engine_extra_embeddings():
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
get_model_old = get_model
|
||||
|
||||
def get_model_patched(**kwargs):
|
||||
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
|
||||
max_lora_rank=8)
|
||||
return get_model_old(**kwargs)
|
||||
|
||||
with patch("vllm.worker.model_runner.get_model", get_model_patched):
|
||||
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
|
||||
yield engine.llm_engine
|
||||
del engine
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
|
||||
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_default_device():
|
||||
"""
|
||||
|
||||
@ -5,7 +5,6 @@ import time
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm.envs as env
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
|
||||
# Run with warmup
|
||||
add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
|
||||
add_lora_results = await asyncio.gather(*add_lora_tasks)
|
||||
if env.VLLM_USE_V1:
|
||||
# Test that all all_lora calls are successful.
|
||||
assert all(add_lora_results)
|
||||
else:
|
||||
# No way to check V0 engine results as the calls just return None.
|
||||
pass
|
||||
|
||||
# Test that all all_lora calls are successful.
|
||||
assert all(add_lora_results)
|
||||
|
||||
time_with_add_lora = await requests_processing_time(
|
||||
llm, warmup_run_requests)
|
||||
|
||||
|
||||
@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
|
||||
enable_lora=True,
|
||||
# also test odd max_num_seqs
|
||||
max_num_seqs=13,
|
||||
max_loras=4,
|
||||
enable_chunked_prefill=True)
|
||||
max_loras=4)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
|
||||
|
||||
@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=4,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
|
||||
@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
|
||||
max_loras=4,
|
||||
tensor_parallel_size=4,
|
||||
fully_sharded_loras=True,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
generate_and_test(llm, sql_lora_files)
|
||||
|
||||
|
||||
@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
|
||||
WorkerLoRAManager)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import create_peft_lora
|
||||
|
||||
EMBEDDING_MODULES = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
@ -35,17 +37,6 @@ DEVICES = ([
|
||||
DEFAULT_DTYPE = torch.get_default_dtype()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Some tests depend on V0 internals. Since both V0 and V1 use the same
|
||||
LoRAModelManager it is okay to just test V0.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv('VLLM_USE_V1', '0')
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_from_lora_tensors(sql_lora_files, device):
|
||||
tensors = load_file(
|
||||
@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
max_loras=2,
|
||||
lora_dtype=DEFAULT_DTYPE),
|
||||
device=device)
|
||||
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
|
||||
# Add up to capacity
|
||||
@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
|
||||
tmp_path):
|
||||
lora_config = LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=4,
|
||||
max_loras=4,
|
||||
lora_dtype=DEFAULT_DTYPE)
|
||||
|
||||
dummy_lora_files = f"{tmp_path}/lora_adapter"
|
||||
os.makedirs(dummy_lora_files, exist_ok=True)
|
||||
create_peft_lora(
|
||||
dummy_model,
|
||||
save_dir=dummy_lora_files,
|
||||
target_modules=["layer1.dense1", "dense2"],
|
||||
lora_dtype=DEFAULT_DTYPE,
|
||||
)
|
||||
worker_adapter_manager = LRUCacheWorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
4, 2,
|
||||
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
|
||||
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_adapter_manager.create_lora_manager(dummy_model)
|
||||
|
||||
mapping = LoRAMapping([], [])
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("2", 2, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("3", 3, sql_lora_files),
|
||||
LoRARequest("4", 4, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("3", 3, dummy_lora_files),
|
||||
LoRARequest("4", 4, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files),
|
||||
LoRARequest("5", 5, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("2", 2, dummy_lora_files),
|
||||
LoRARequest("5", 5, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("1", 1, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("6", 6, sql_lora_files),
|
||||
LoRARequest("7", 7, sql_lora_files),
|
||||
LoRARequest("8", 8, sql_lora_files)
|
||||
LoRARequest("6", 6, dummy_lora_files),
|
||||
LoRARequest("7", 7, dummy_lora_files),
|
||||
LoRARequest("8", 8, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
# Over capacity
|
||||
with pytest.raises(RuntimeError):
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("10", 10, sql_lora_files),
|
||||
LoRARequest("11", 11, sql_lora_files),
|
||||
LoRARequest("12", 12, sql_lora_files),
|
||||
LoRARequest("13", 13, sql_lora_files),
|
||||
LoRARequest("14", 14, sql_lora_files)
|
||||
LoRARequest("10", 10, dummy_lora_files),
|
||||
LoRARequest("11", 11, dummy_lora_files),
|
||||
LoRARequest("12", 12, dummy_lora_files),
|
||||
LoRARequest("13", 13, dummy_lora_files),
|
||||
LoRARequest("14", 14, dummy_lora_files)
|
||||
], mapping)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files, device):
|
||||
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
|
||||
tmp_path):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
lora_config = LoRAConfig(max_lora_rank=8,
|
||||
max_cpu_loras=4,
|
||||
max_loras=4,
|
||||
lora_dtype=DEFAULT_DTYPE)
|
||||
worker_adapter_manager = WorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
4, 2, dummy_model_gate_up.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, device,
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
|
||||
|
||||
dummy_lora_files = f"{tmp_path}/lora_adapter"
|
||||
os.makedirs(dummy_lora_files, exist_ok=True)
|
||||
create_peft_lora(
|
||||
dummy_model_gate_up,
|
||||
save_dir=dummy_lora_files,
|
||||
target_modules=["layer1.dense1", "dense2"],
|
||||
lora_dtype=DEFAULT_DTYPE,
|
||||
)
|
||||
|
||||
mapping = LoRAMapping([], [])
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("2", 2, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("3", 3, sql_lora_files),
|
||||
LoRARequest("4", 4, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("3", 3, dummy_lora_files),
|
||||
LoRARequest("4", 4, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files),
|
||||
LoRARequest("5", 5, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("2", 2, dummy_lora_files),
|
||||
LoRARequest("5", 5, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files)
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("1", 1, dummy_lora_files),
|
||||
LoRARequest("1", 1, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {1}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
|
||||
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("6", 6, sql_lora_files),
|
||||
LoRARequest("7", 7, sql_lora_files),
|
||||
LoRARequest("8", 8, sql_lora_files)
|
||||
LoRARequest("6", 6, dummy_lora_files),
|
||||
LoRARequest("7", 7, dummy_lora_files),
|
||||
LoRARequest("8", 8, dummy_lora_files)
|
||||
], mapping)
|
||||
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
|
||||
@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
# Over capacity
|
||||
with pytest.raises(RuntimeError):
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("10", 10, sql_lora_files),
|
||||
LoRARequest("11", 11, sql_lora_files),
|
||||
LoRARequest("12", 12, sql_lora_files),
|
||||
LoRARequest("13", 13, sql_lora_files),
|
||||
LoRARequest("14", 14, sql_lora_files)
|
||||
LoRARequest("10", 10, dummy_lora_files),
|
||||
LoRARequest("11", 11, dummy_lora_files),
|
||||
LoRARequest("12", 12, dummy_lora_files),
|
||||
LoRARequest("13", 13, dummy_lora_files),
|
||||
LoRARequest("14", 14, dummy_lora_files)
|
||||
], mapping)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
|
||||
@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
|
||||
max_loras=4,
|
||||
distributed_executor_backend="ray",
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
|
||||
expected_lora_output = [
|
||||
|
||||
@ -4,17 +4,14 @@
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from typing import Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VllmConfig)
|
||||
from vllm.lora.models import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.worker.gpu_worker import Worker as V1Worker
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.v1.worker.gpu_worker import Worker
|
||||
|
||||
NUM_LORAS = 16
|
||||
|
||||
@ -22,18 +19,11 @@ NUM_LORAS = 16
|
||||
@patch.dict(os.environ, {"RANK": "0"})
|
||||
def test_worker_apply_lora(sql_lora_files):
|
||||
|
||||
def set_active_loras(worker: Union[Worker, V1Worker],
|
||||
lora_requests: list[LoRARequest]):
|
||||
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
|
||||
lora_mapping = LoRAMapping([], [])
|
||||
if isinstance(worker, Worker):
|
||||
# v0 case
|
||||
worker.model_runner.set_active_loras(lora_requests, lora_mapping)
|
||||
else:
|
||||
# v1 case
|
||||
worker.model_runner.lora_manager.set_active_adapters(
|
||||
lora_requests, lora_mapping)
|
||||
|
||||
worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker
|
||||
worker.model_runner.lora_manager.set_active_adapters(
|
||||
lora_requests, lora_mapping)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(
|
||||
@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
|
||||
max_cpu_loras=NUM_LORAS,
|
||||
max_loras=NUM_LORAS),
|
||||
)
|
||||
worker = worker_cls(
|
||||
worker = Worker(
|
||||
vllm_config=vllm_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
|
||||
@ -340,3 +343,76 @@ def generate_data_for_nslices(
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
)
|
||||
|
||||
|
||||
def create_peft_lora(
|
||||
model: torch.nn.Module,
|
||||
save_dir: str,
|
||||
target_modules: list[str],
|
||||
rank: int = 8,
|
||||
alpha: int = 16,
|
||||
dropout: float = 0.1,
|
||||
lora_dtype: torch.dtype = torch.float16,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
lora_weights = {}
|
||||
adapter_config = {
|
||||
"peft_type": "LORA",
|
||||
"auto_mapping": None,
|
||||
"base_model_name_or_path": "dummy_model",
|
||||
"revision": None,
|
||||
"task_type": "CAUSAL_LM",
|
||||
"inference_mode": False,
|
||||
"r": rank,
|
||||
"lora_alpha": alpha,
|
||||
"lora_dropout": dropout,
|
||||
"fan_in_fan_out": False,
|
||||
"bias": "none",
|
||||
"modules_to_save": None,
|
||||
"init_lora_weights": True,
|
||||
"layers_to_transform": None,
|
||||
"layers_pattern": None,
|
||||
"target_modules": target_modules,
|
||||
"exclude_modules": None,
|
||||
"use_rslora": False,
|
||||
"use_dora": False,
|
||||
"loftq_config": None,
|
||||
}
|
||||
|
||||
for module_name in target_modules:
|
||||
|
||||
module = model
|
||||
for attr in module_name.split("."):
|
||||
module = getattr(module, attr)
|
||||
|
||||
if hasattr(module, "input_size") and hasattr(module, "output_size"):
|
||||
|
||||
in_features = module.input_size
|
||||
out_features = module.output_size
|
||||
|
||||
elif hasattr(module, "embedding_dim") and hasattr(
|
||||
module, "num_embeddings"):
|
||||
# ParallelLMHead
|
||||
in_features = module.embedding_dim
|
||||
out_features = module.num_embeddings
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to determine dimensions for module {module_name}")
|
||||
|
||||
lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)
|
||||
|
||||
lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)
|
||||
|
||||
# PEFT style
|
||||
lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
|
||||
lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B
|
||||
|
||||
config_path = os.path.join(save_dir, "adapter_config.json")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(adapter_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
weights_path = os.path.join(save_dir, "adapter_model.safetensors")
|
||||
save_file(lora_weights, weights_path)
|
||||
|
||||
return lora_weights
|
||||
|
||||
@ -11,7 +11,6 @@ from pathlib import PosixPath
|
||||
import pytest
|
||||
from transformers import (AutoModel, AutoModelForImageTextToText,
|
||||
AutoModelForTextToWaveform, AutoModelForVision2Seq)
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import identity
|
||||
@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="half",
|
||||
num_logprobs=10,
|
||||
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
|
||||
marks=[pytest.mark.skipif(
|
||||
not is_flash_attn_2_available(),
|
||||
reason="HF model needs `flash_attn` installed"
|
||||
)],
|
||||
hf_model_kwargs={"revision": "refs/pr/5"},
|
||||
),
|
||||
"phi3v": VLMTestInfo(
|
||||
models=["microsoft/Phi-3.5-vision-instruct"],
|
||||
|
||||
@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
|
||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||
@ -465,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
|
||||
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
|
||||
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
|
||||
trust_remote_code=True,
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model is not compatible"), # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
||||
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
||||
|
||||
459
tests/tool_use/test_seed_oss_tool_parser.py
Normal file
459
tests/tool_use/test_seed_oss_tool_parser.py
Normal file
@ -0,0 +1,459 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
|
||||
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def seed_oss_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seed_oss_tool_parser(seed_oss_tokenizer):
|
||||
return SeedOssToolParser(seed_oss_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"City and country e.g. Bogotá, Colombia"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "this is the unit of temperature"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "temperature in celsius"
|
||||
}
|
||||
},
|
||||
"required": ["temperature"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
"strict": True
|
||||
}),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
# Seed-OSS tool call will not generate id
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=request) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
|
||||
result = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="his is a test response",
|
||||
current_text=model_output,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
seed_oss_tool_parser: SeedOssToolParser,
|
||||
seed_oss_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = seed_oss_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=seed_oss_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||
sample_tools, model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
"""Test incremental streaming behavior"""
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
|
||||
other_content = ''
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
# Initialize state for new tool
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
# Should only be set once
|
||||
assert tool_states[idx]["name"] is None
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx][
|
||||
"arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == expected_content
|
||||
|
||||
# Verify we got all expected tool calls
|
||||
assert len(tool_states) == len(expected_tool_calls)
|
||||
|
||||
# Verify each tool call
|
||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||
state = tool_states[idx]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == expected_tool.function.name
|
||||
|
||||
# Parse accumulated arguments
|
||||
arguments_str = state["arguments"]
|
||||
assert arguments_str is not None
|
||||
actual_args = json.loads(arguments_str)
|
||||
expected_args = json.loads(expected_tool.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
|
||||
import jsonschema
|
||||
import pytest
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
|
||||
assert "a4" not in generated
|
||||
assert "a5" not in generated
|
||||
assert "a6" not in generated
|
||||
|
||||
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["guidance", "xgrammar", "outlines"])
|
||||
def test_structured_output_batched_with_non_guided_requests(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_json_schema: dict[str, Any],
|
||||
guided_decoding_backend: str,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
# Don't use eager execution on TPUs because we want to test for no
|
||||
# recompilation at runtime
|
||||
enforce_eager = bool(not current_platform.is_tpu())
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
enforce_eager=enforce_eager,
|
||||
max_model_len=1024,
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
||||
in {"xgrammar", "guidance"}),
|
||||
)
|
||||
|
||||
guided_prompt = (
|
||||
"Give an example JSON for an employee profile that fits this "
|
||||
"schema. Make the response as short as possible. Schema: "
|
||||
f"{sample_json_schema}")
|
||||
|
||||
non_guided_prompt = "The diameter of the Earth in kilometers is "
|
||||
|
||||
prompts = [guided_prompt, non_guided_prompt]
|
||||
sampling_params = [
|
||||
SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=400,
|
||||
guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
|
||||
# No max tokens, temp=0 to assert on contents
|
||||
SamplingParams(
|
||||
seed=42,
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = llm.generate(prompts=prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
# Free memory as soon as possible as failed assertions
|
||||
# will short circuit and not free up memory
|
||||
del llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
for index, output in enumerate(outputs):
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")
|
||||
|
||||
if index == 0:
|
||||
# First prompt is guided, expect valid JSON
|
||||
assert "\n" not in generated_text
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json,
|
||||
schema=sample_json_schema)
|
||||
else:
|
||||
# Second prompt is not guided, expect valid output
|
||||
# Cannot assert on exact output, but we can expect it to be factual
|
||||
assert "12,742" in generated_text
|
||||
|
||||
# non-guided requests should not return a valid JSON here
|
||||
with pytest.raises(ValueError):
|
||||
output_json = json.loads(generated_text)
|
||||
|
||||
@ -14,6 +14,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.config import KVTransferConfig
|
||||
@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlConnectorWorker)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
|
||||
@ -98,7 +100,6 @@ class FakeNixlWrapper:
|
||||
|
||||
def set_cycles_before_xfer_done(self, cycles: int):
|
||||
"""Set the number of cycles before a transfer is considered done."""
|
||||
self._cycles_before_xfer_done = cycles
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
|
||||
sampling_params)
|
||||
# Request-0 times out and is cleared!
|
||||
assert '0' not in req_to_blocks
|
||||
|
||||
|
||||
def test_register_kv_caches(dist_init):
|
||||
"""
|
||||
Test that register_kv_caches() properly calls nixl_wrapper methods with
|
||||
correct data.
|
||||
|
||||
This test verifies:
|
||||
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
|
||||
tensor metadata
|
||||
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
|
||||
block layout info
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Create test kv cache tensors using proper backend shape
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=16,
|
||||
num_kv_heads=4,
|
||||
head_size=64)
|
||||
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {
|
||||
"layer0": shared_tensor,
|
||||
"layer1": unique_tensor,
|
||||
"layer2": shared_tensor,
|
||||
}
|
||||
|
||||
# Store tensor info for validation
|
||||
expected_tensor_size = shared_tensor[0].element_size(
|
||||
) * shared_tensor[0].numel()
|
||||
expected_base_addrs = [
|
||||
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
|
||||
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
|
||||
]
|
||||
|
||||
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
|
||||
|
||||
# Create connector
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0)
|
||||
|
||||
# Get the mock instance
|
||||
mock_wrapper_instance = mock_nixl_wrapper.return_value
|
||||
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Verify get_reg_descs was called with caches_data
|
||||
assert mock_wrapper_instance.get_reg_descs.called
|
||||
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
|
||||
assert len(caches_data) == 4
|
||||
|
||||
for i, cache_entry in enumerate(caches_data):
|
||||
base_addr, size, _tp_rank, _ = cache_entry
|
||||
assert size == expected_tensor_size, \
|
||||
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
|
||||
f"got {size}"
|
||||
assert base_addr == expected_base_addrs[i], \
|
||||
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
|
||||
f"got {base_addr}"
|
||||
|
||||
# Verify get_xfer_descs was called with blocks_data
|
||||
assert mock_wrapper_instance.get_xfer_descs.called
|
||||
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
|
||||
|
||||
# Validate blocks_data structure and size
|
||||
expected_blocks_count = 8
|
||||
assert len(blocks_data) == expected_blocks_count, \
|
||||
f"Expected {expected_blocks_count} blocks, " \
|
||||
f"got {len(blocks_data)}"
|
||||
|
||||
expected_block_len = expected_tensor_size // 2
|
||||
for i, block_entry in enumerate(blocks_data):
|
||||
block_start_addr, block_len, tp_rank = block_entry
|
||||
assert block_len == expected_block_len, \
|
||||
f"Block entry {i}: Expected block len {expected_block_len}, " \
|
||||
f"got {block_len}"
|
||||
|
||||
@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
kv_cache_spec[layer_0].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
kv_cache_config_after_init = runner.kv_cache_config
|
||||
|
||||
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
assert id(layer_1_kv) == id(layer_0_kv)
|
||||
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
assert len(kv_cache_config_after_init.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||
0] == layer_0
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||
1] == layer_1
|
||||
|
||||
|
||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
|
||||
@ -37,7 +37,7 @@ ALLOWED_FILES = set([
|
||||
'vllm/distributed/utils.py',
|
||||
'vllm/distributed/parallel_state.py',
|
||||
'vllm/engine/multiprocessing/client.py',
|
||||
'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
|
||||
'vllm/distributed/device_communicators/all_reduce_utils.py',
|
||||
'vllm/distributed/device_communicators/shm_broadcast.py',
|
||||
'vllm/engine/multiprocessing/engine.py',
|
||||
'benchmarks/kernels/graph_machete_bench.py',
|
||||
|
||||
108
tools/install_deepgemm.sh
Executable file
108
tools/install_deepgemm.sh
Executable file
@ -0,0 +1,108 @@
|
||||
#!/bin/bash
|
||||
# Script to install DeepGEMM from source
|
||||
# This script can be used both in Docker builds and by users locally
|
||||
|
||||
set -e
|
||||
|
||||
# Default values
|
||||
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--ref)
|
||||
if [[ -z "$2" || "$2" =~ ^- ]]; then
|
||||
echo "Error: --ref requires an argument." >&2
|
||||
exit 1
|
||||
fi
|
||||
DEEPGEMM_GIT_REF="$2"
|
||||
shift 2
|
||||
;;
|
||||
--cuda-version)
|
||||
if [[ -z "$2" || "$2" =~ ^- ]]; then
|
||||
echo "Error: --cuda-version requires an argument." >&2
|
||||
exit 1
|
||||
fi
|
||||
CUDA_VERSION="$2"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [OPTIONS]"
|
||||
echo "Options:"
|
||||
echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
|
||||
echo " --cuda-version VER CUDA version (auto-detected if not provided)"
|
||||
echo " -h, --help Show this help message"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Auto-detect CUDA version if not provided
|
||||
if [ -z "$CUDA_VERSION" ]; then
|
||||
if command -v nvcc >/dev/null 2>&1; then
|
||||
CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p')
|
||||
echo "Auto-detected CUDA version: $CUDA_VERSION"
|
||||
else
|
||||
echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Extract major and minor version numbers
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"
|
||||
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
|
||||
CUDA_MINOR="${CUDA_MINOR%%.*}"
|
||||
|
||||
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"
|
||||
|
||||
# Check CUDA version requirement
|
||||
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
|
||||
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Installing DeepGEMM from source..."
|
||||
echo "Repository: $DEEPGEMM_GIT_REPO"
|
||||
echo "Reference: $DEEPGEMM_GIT_REF"
|
||||
|
||||
# Create a temporary directory for the build
|
||||
INSTALL_DIR=$(mktemp -d)
|
||||
trap 'rm -rf "$INSTALL_DIR"' EXIT
|
||||
|
||||
# Clone the repository
|
||||
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
|
||||
|
||||
echo "🏗️ Building DeepGEMM"
|
||||
pushd "$INSTALL_DIR/deepgemm"
|
||||
|
||||
# Checkout the specific reference
|
||||
git checkout "$DEEPGEMM_GIT_REF"
|
||||
|
||||
# Build DeepGEMM
|
||||
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python3 setup.py bdist_wheel
|
||||
|
||||
# Install the wheel
|
||||
if command -v uv >/dev/null 2>&1; then
|
||||
echo "Installing DeepGEMM wheel using uv..."
|
||||
# Use --system in Docker contexts, respect user's environment otherwise
|
||||
if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
|
||||
uv pip install --system dist/*.whl
|
||||
else
|
||||
uv pip install dist/*.whl
|
||||
fi
|
||||
else
|
||||
echo "Installing DeepGEMM wheel using pip..."
|
||||
python3 -m pip install dist/*.whl
|
||||
fi
|
||||
|
||||
popd
|
||||
|
||||
echo "✅ DeepGEMM installation completed successfully"
|
||||
@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
||||
group_shape: GroupShape):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
"""
|
||||
Does this attention implementation support fused output quantization.
|
||||
This is used by the AttnFusionPass to only fuse output quantization
|
||||
onto implementations that support it.
|
||||
|
||||
TODO(luka) merge parameters into QuantDescriptor
|
||||
:param dtype: quantized dtype
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quant group shape.
|
||||
:param quant_key: QuantKey object that describes the quantization op
|
||||
:return: is fusion supported for this type of quantization
|
||||
"""
|
||||
return False
|
||||
@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: DifferentialFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for DifferentialFlashAttentionImpl")
|
||||
|
||||
if self.lambda_full is None:
|
||||
self.lambda_init = self.differential_flash_attention_config[
|
||||
"lambda_init"]
|
||||
|
||||
@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
attn_metadata: DualChunkFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with DualChunkFlashAttention.
|
||||
Args:
|
||||
@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
"""
|
||||
assert output is None, "Output tensor not supported for DualChunk"
|
||||
|
||||
if output_scale is not None:
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
if output_scale is not None:
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for MLAImplBase")
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
QuantKey, kFp8StaticTensorSym)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
head_dim).reshape(tokens, n_kv_heads * n_rep,
|
||||
head_dim))
|
||||
|
||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
||||
group_shape: GroupShape):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
if self.use_triton_flash_attn:
|
||||
return dtype == current_platform.fp8_dtype(
|
||||
) and static and group_shape == GroupShape.PER_TENSOR
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
# Only supported in the Triton backend
|
||||
return False
|
||||
@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
"fused output quantization only supported for Triton"
|
||||
" implementation in ROCMFlashAttentionImpl for now")
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused nvfp4 output quantization is not supported"
|
||||
" for ROCMFlashAttentionImpl")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
|
||||
@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
attn_metadata: "XFormersMetadata",
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
|
||||
@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for XFormersImpl")
|
||||
|
||||
@ -495,6 +495,7 @@ def unified_attention_with_output(
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
@ -510,7 +511,8 @@ def unified_attention_with_output(
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale)
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
@ -522,6 +524,7 @@ def unified_attention_with_output_fake(
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@ -529,7 +532,7 @@ def unified_attention_with_output_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention_with_output",
|
||||
op_func=unified_attention_with_output,
|
||||
mutates_args=["output"],
|
||||
mutates_args=["output", "output_block_scale"],
|
||||
fake_impl=unified_attention_with_output_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -6,12 +6,13 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||
subclass_attention_backend, subclass_attention_metadata_builder)
|
||||
subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||
|
||||
def build_preprocess_fn(cm: CommonAttentionMetadata):
|
||||
return make_local_attention_virtual_batches(attention_chunk_size, cm,
|
||||
block_size)
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> AttentionMetadata:
|
||||
common_attn_metadata = make_local_attention_virtual_batches(
|
||||
attention_chunk_size, common_attn_metadata, block_size)
|
||||
return super().build(common_prefix_len, common_attn_metadata,
|
||||
fast_build)
|
||||
|
||||
# Dynamically create a new attention backend that wraps the
|
||||
# underlying attention backend but applies
|
||||
# `make_local_attention_virtual_batches` before calling `build(...)`
|
||||
builder_cls = subclass_attention_metadata_builder(
|
||||
name_prefix=prefix,
|
||||
builder_cls=underlying_attn_backend.get_builder_cls(),
|
||||
build_preprocess_fn=build_preprocess_fn)
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=builder_cls)
|
||||
builder_cls=ChunkedLocalAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
86
vllm/attention/layers/encoder_only_attention.py
Normal file
86
vllm/attention/layers/encoder_only_attention.py
Normal file
@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from copy import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
subclass_attention_backend)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_encoder_only_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
|
||||
prefix = "EncoderOnlyAttention_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> AttentionMetadata:
|
||||
new_common_attn_metadata = copy(common_attn_metadata)
|
||||
new_common_attn_metadata.causal = False
|
||||
return super().build(common_prefix_len, new_common_attn_metadata,
|
||||
fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=EncoderOnlyAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class EncoderOnlyAttention(Attention):
|
||||
"""
|
||||
Encoder attention is a special case that doesn't need a KV Cache.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
attn_type: Optional[str] = None,
|
||||
**kwargs):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||
kv_cache_dtype,
|
||||
block_size)
|
||||
|
||||
attn_backend = create_encoder_only_attention_backend(
|
||||
underlying_attn_backend)
|
||||
else:
|
||||
# in v0 encoder only attention is handled inside the backends
|
||||
attn_backend = None
|
||||
|
||||
if attn_type is not None:
|
||||
assert attn_type == AttentionType.ENCODER_ONLY, \
|
||||
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
|
||||
|
||||
super().__init__(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
**kwargs)
|
||||
@ -18,7 +18,7 @@ class BeamSearchSequence:
|
||||
The text field is optional and will only be filled when the sequence is
|
||||
about to be returned to the user.
|
||||
"""
|
||||
# The tokens includes the prompt.
|
||||
# The tokens include the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
@ -484,7 +484,7 @@ class VllmBackend:
|
||||
|
||||
factors = []
|
||||
# 0. factors come from the env, for example, The values of
|
||||
# VLLM_PP_LAYER_PARTITION will affects the computation graph.
|
||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||
env_hash = envs.compute_hash()
|
||||
factors.append(env_hash)
|
||||
|
||||
|
||||
@ -12,7 +12,8 @@ from torch._ops import OpOverload
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import find_getitem_maybe
|
||||
@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class QuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of quantization.
|
||||
dtype: quantized data type
|
||||
static: static quantization if True, dynamic if False
|
||||
group_shape: quantization group shape
|
||||
symmetric: symmetric if True, asymmetric if False
|
||||
|
||||
TODO(luka) use QuantDescriptor once standardized:
|
||||
https://github.com/vllm-project/vllm/issues/8913
|
||||
|
||||
"""
|
||||
dtype: torch.dtype
|
||||
static: bool
|
||||
group_shape: GroupShape
|
||||
symmetric: bool = True
|
||||
|
||||
def __str__(self):
|
||||
group_shape = ('per_tensor'
|
||||
if self.group_shape == GroupShape.PER_TENSOR else
|
||||
('per_token' if self.group_shape == GroupShape.PER_TOKEN
|
||||
else str(self.group_shape)))
|
||||
|
||||
return (f"QuantKey({'static' if self.static else 'dynamic'},"
|
||||
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
|
||||
f"{'a' if not self.symmetric else ''}symmetric)")
|
||||
|
||||
|
||||
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym:
|
||||
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym:
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
fused_key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True):
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric))
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass,
|
||||
@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric=True):
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype,
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
scale=scale,
|
||||
symmetric=symmetric))
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kNvfp4Quant, kStaticTensorScale)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionStaticQuantPattern:
|
||||
class AttentionQuantPattern(ABC):
|
||||
"""
|
||||
Fusion for Attention+StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
|
||||
op will be removed from the graph, and its scale will be passed into
|
||||
Attention op as the `output_scale` argument.
|
||||
The base class for Attn+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_dtype: torch.dtype,
|
||||
symmetric=True,
|
||||
quant_key: QuantKey,
|
||||
):
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.head_size = layer.head_size
|
||||
self.quant_dtype = quant_dtype
|
||||
self.quant_key = QuantKey(dtype=quant_dtype,
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
symmetric=symmetric)
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
@ -57,12 +56,49 @@ class AttentionStaticQuantPattern:
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(
|
||||
self.quant_dtype, self.quant_key.static,
|
||||
self.quant_key.group_shape):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Fp8StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
super().__init__(layer, quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
@ -74,9 +110,10 @@ class AttentionStaticQuantPattern:
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None)
|
||||
attn_out_view = RESHAPE_OP(at1[1],
|
||||
[-1, self.num_heads * self.head_size])
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
result=output_quant,
|
||||
input=attn_out_view,
|
||||
@ -98,7 +135,8 @@ class AttentionStaticQuantPattern:
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale)
|
||||
output_scale=scale,
|
||||
output_block_scale=None)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
||||
@ -114,21 +152,94 @@ class AttentionStaticQuantPattern:
|
||||
empty_fp32(1, 1) # scale
|
||||
]
|
||||
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Nvfp4Quant.
|
||||
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Nvfp4Quant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention):
|
||||
super().__init__(layer, kNvfp4Quant)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
at1 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size])
|
||||
at2 = auto_functionalized(self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
output_attn: torch.Tensor, output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor, input_scale: torch.Tensor):
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(
|
||||
output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view)
|
||||
output = RESHAPE_OP(at2[1],
|
||||
[-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
# Need custom fake mode, otherwise tracing happens with real tensors.
|
||||
# That would not work for the unified_attention custom op.
|
||||
with unset_fake_temporarily(), FakeTensorMode():
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size //
|
||||
2), # output_quant
|
||||
empty_i32(128,
|
||||
round_up(self.num_heads * self.head_size // 16,
|
||||
4)), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, inputs,
|
||||
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
|
||||
pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmInductorPass):
|
||||
@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass):
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
|
||||
pattern.register_if_supported(self.patterns)
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass):
|
||||
self.end_and_log()
|
||||
|
||||
def uuid(self):
|
||||
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
|
||||
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern)
|
||||
|
||||
@ -1119,9 +1119,20 @@ class ModelConfig:
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = me_quant.QUANTIZATION_METHODS
|
||||
optimized_quantization_methods = [
|
||||
"fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
|
||||
"fbgemm_fp8", "compressed-tensors", "experts_int8", "quark",
|
||||
"modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
|
||||
"fp8",
|
||||
"modelopt",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"fbgemm_fp8",
|
||||
"compressed-tensors",
|
||||
"experts_int8",
|
||||
"quark",
|
||||
"modelopt_fp4",
|
||||
"bitblas",
|
||||
"gptq_bitblas",
|
||||
"inc",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
if self.quantization is not None:
|
||||
self.quantization = cast(me_quant.QuantizationMethods,
|
||||
@ -1153,6 +1164,7 @@ class ModelConfig:
|
||||
"moe_wna16",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
|
||||
@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: MiB // 2, # 512 KB
|
||||
8: MiB // 4, # 256 KB
|
||||
},
|
||||
"10.0": {
|
||||
2: 2 * MiB, # 2 MB
|
||||
4: 2 * MiB, # 2 MB
|
||||
6: 2 * MiB, # 2 MB
|
||||
8: 2 * MiB, # 2 MB
|
||||
}
|
||||
}
|
||||
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 64 * MiB, # 64 MB
|
||||
8: 64 * MiB, # 64 MB
|
||||
},
|
||||
"10.0": {
|
||||
2: 8 * MiB, # 8 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
PyNcclCommunicator)
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce)
|
||||
from vllm.distributed.device_communicators.symm_mem import (
|
||||
SymmMemCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
self.ca_comm = CustomAllreduce(
|
||||
@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
# currently be an MI300 series.
|
||||
self.qr_comm = QuickAllReduce(group=self.cpu_group,
|
||||
device=self.device)
|
||||
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
symm_mem_comm = self.symm_mem_comm
|
||||
if symm_mem_comm is not None and \
|
||||
symm_mem_comm.should_use_symm_mem(input_):
|
||||
out = symm_mem_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
|
||||
@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -109,7 +109,13 @@ class CustomAllreduce:
|
||||
# now `device` is a `torch.device` object
|
||||
assert isinstance(device, torch.device)
|
||||
self.device = device
|
||||
|
||||
device_capability = current_platform.get_device_capability(
|
||||
).as_version_str()
|
||||
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
|
||||
max_size = min(
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
|
||||
max_size)
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
|
||||
111
vllm/distributed/device_communicators/symm_mem.py
Normal file
111
vllm/distributed/device_communicators/symm_mem.py
Normal file
@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
import torch.distributed._symmetric_memory as torch_symm_mem
|
||||
|
||||
symm_mem_available = True
|
||||
except ImportError:
|
||||
symm_mem_available = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SymmMemCommunicator:
|
||||
_WORLD_SIZES_MULTIMEM = {
|
||||
"9.0": [4, 6, 8],
|
||||
"10.0": [6, 8],
|
||||
}
|
||||
|
||||
def __init__(self, group: ProcessGroup, device: Union[int, str,
|
||||
torch.device]):
|
||||
self.disabled = True
|
||||
|
||||
if not symm_mem_available:
|
||||
return
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
logger.warning("SymmMemCommunicator: symmetric "
|
||||
"memory is not available.")
|
||||
return
|
||||
if isinstance(device, int):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
torch.cuda.set_device(device)
|
||||
self.dtype = torch.bfloat16
|
||||
self.device = device
|
||||
self.group = group
|
||||
self.world_size = dist.get_world_size(self.group)
|
||||
self.device_capability = current_platform.get_device_capability(
|
||||
).as_version_str()
|
||||
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: Device capability %s not supported, "
|
||||
"communicator is not available.",
|
||||
self.device_capability,
|
||||
)
|
||||
return
|
||||
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
|
||||
self.device_capability]:
|
||||
logger.warning(
|
||||
"SymmMemCommunicator: World size %d not supported, "
|
||||
"communicator is not available.",
|
||||
self.world_size,
|
||||
)
|
||||
return
|
||||
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
|
||||
self.world_size]
|
||||
self.buffer = torch_symm_mem.empty(
|
||||
self.max_size // self.dtype.itemsize,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
|
||||
if handle.multicast_ptr == 0:
|
||||
logger.warning("SymmMemCommunicator: symmetric memory "
|
||||
"multicast operations are not supported.")
|
||||
return
|
||||
self.disabled = False
|
||||
|
||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
return False
|
||||
if inp.dtype != self.dtype:
|
||||
return False
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
if inp_size % 4 != 0:
|
||||
return False
|
||||
return inp_size < self.max_size
|
||||
|
||||
def all_reduce(
|
||||
self,
|
||||
inp: torch.Tensor,
|
||||
*,
|
||||
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
|
||||
if not self.should_use_symm_mem(inp):
|
||||
return None
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
self.buffer[:inp.numel()].copy_(inp.view(-1))
|
||||
if self.world_size in self._WORLD_SIZES_MULTIMEM[
|
||||
self.device_capability]:
|
||||
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
|
||||
"sum",
|
||||
self.group.group_name)
|
||||
else:
|
||||
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
|
||||
"sum",
|
||||
self.group.group_name)
|
||||
out.copy_(self.buffer[:inp.numel()].view(out.shape))
|
||||
return out
|
||||
@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args: kv_caches:
|
||||
dictionary of layer names, kv cache
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
@ -686,9 +686,6 @@ class NixlConnectorWorker:
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in nixl."""
|
||||
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
|
||||
if self.use_host_buffer:
|
||||
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
|
||||
assert len(self.host_xfer_buffers) == len(kv_caches), (
|
||||
@ -701,66 +698,16 @@ class NixlConnectorWorker:
|
||||
"host_xfer_buffer should not be initialized when "
|
||||
f"kv_buffer_device is {self.kv_buffer_device}")
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
|
||||
# KV memory layout is HND, as opposed to the default NHD. Note that it
|
||||
# will only affects the strides. For MLA instead, we make require no
|
||||
# such thing and resort to the standard layout.
|
||||
use_mla = len(first_kv_cache.shape) == 3
|
||||
if self.device_type == "tpu":
|
||||
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
|
||||
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
|
||||
# tpu (v1) kv shape per layer:
|
||||
# (num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, n_kv_heads_x_2, head_dim = block_shape
|
||||
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
|
||||
elif self.device_type == "cuda":
|
||||
assert use_mla == self.use_mla
|
||||
# TODO (NickLucche) not compatible with hybrid allocator.
|
||||
# Enforce check once it goes live, as a single kv layout
|
||||
# is expected for xfers.
|
||||
if use_mla:
|
||||
# MLA case.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 2 # [block_size, latent_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, kv_latent_dim = block_shape
|
||||
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
||||
else:
|
||||
# [2 (k and v), num_blocks, ...]
|
||||
if self._use_flashinfer:
|
||||
# FlashInfer swaps 2<->num_blocks dimensions.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 4 # [2, block_size, kv_heads, head_dim]
|
||||
else:
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, n_kv_heads, head_dim = block_shape[-3:]
|
||||
# head size in bytes.
|
||||
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
|
||||
assert block_size == self.block_size
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{self.device_type} ({self.backend_name}) is not supported.")
|
||||
|
||||
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
||||
# hybrid attn, etc
|
||||
# block size in bytes
|
||||
self.block_len = kv_elem_size * math.prod(block_shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
|
||||
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
|
||||
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
|
||||
self.use_host_buffer, self.num_blocks, block_shape,
|
||||
first_kv_cache.shape)
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
self.device_kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
"use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
|
||||
self.use_host_buffer)
|
||||
|
||||
caches_data = []
|
||||
# With hybrid allocator, layers can share a kv cache tensor
|
||||
seen_base_addresses = []
|
||||
xfer_buffers = (self.host_xfer_buffers
|
||||
if self.use_host_buffer else kv_caches)
|
||||
|
||||
# Note(tms): I modified this from the original region setup code.
|
||||
# K and V are now in different regions. Advantage is that we can
|
||||
@ -770,42 +717,35 @@ class NixlConnectorWorker:
|
||||
# (roughly 8KB vs 5KB).
|
||||
# Conversely for FlashInfer, K and V are transferred in the same tensor
|
||||
# to better exploit the memory layout (ie num_blocks is the first dim).
|
||||
for cache_or_caches in xfer_buffers.values():
|
||||
# Normalize to always be a list of caches
|
||||
cache_list = [cache_or_caches] if use_mla \
|
||||
or self._use_pallas_v1 or self._use_flashinfer \
|
||||
else cache_or_caches
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas_v1
|
||||
or self._use_flashinfer)
|
||||
tensor_size_bytes = None
|
||||
for layer_name, cache_or_caches in xfer_buffers.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [
|
||||
cache_or_caches
|
||||
]
|
||||
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
# NOTE: use tp_rank for device_id since multi-node TP
|
||||
# is rarely used.
|
||||
caches_data.append((base_addr, region_len, self.tp_rank, ""))
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
if base_addr in seen_base_addresses:
|
||||
continue
|
||||
|
||||
seen_base_addresses.append(base_addr)
|
||||
curr_tensor_size_bytes = cache.numel() * cache.element_size()
|
||||
|
||||
if tensor_size_bytes is None:
|
||||
tensor_size_bytes = curr_tensor_size_bytes
|
||||
self.num_blocks = cache.shape[0]
|
||||
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, \
|
||||
"All kv cache tensors must have the same size"
|
||||
caches_data.append(
|
||||
(base_addr, tensor_size_bytes, self.tp_rank, ""))
|
||||
|
||||
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
|
||||
self.num_regions = len(caches_data)
|
||||
self.num_layers = len(xfer_buffers.keys())
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||
Llama4TextConfig)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data,
|
||||
self.nixl_memory_type)
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
@ -813,9 +753,20 @@ class NixlConnectorWorker:
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
assert tensor_size_bytes is not None
|
||||
assert self.num_blocks != 0
|
||||
assert tensor_size_bytes % self.num_blocks == 0
|
||||
self.block_len = tensor_size_bytes // self.num_blocks
|
||||
self.slot_size_bytes = self.block_len // self.block_size
|
||||
if self._use_flashinfer:
|
||||
assert self.slot_size_bytes % 2 == 0
|
||||
self.slot_size_bytes /= 2
|
||||
self.device_kv_caches = kv_caches
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||
for base_addr in seen_base_addresses:
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
@ -836,6 +787,26 @@ class NixlConnectorWorker:
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs)
|
||||
|
||||
# TODO(mgoin): Hybrid memory allocator is currently diabled for
|
||||
# models with local attention (Llama 4). Can remove this once enabled.
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||
Llama4TextConfig)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
|
||||
@ -605,7 +605,7 @@ class EngineArgs:
|
||||
**guided_decoding_kwargs["disable_additional_properties"])
|
||||
guided_decoding_group.add_argument(
|
||||
"--reasoning-parser",
|
||||
# This choices is a special case because it's not static
|
||||
# This choice is a special case because it's not static
|
||||
choices=list(ReasoningParserManager.reasoning_parsers),
|
||||
**guided_decoding_kwargs["reasoning_backend"])
|
||||
|
||||
@ -1047,7 +1047,7 @@ class EngineArgs:
|
||||
# details from the config directly
|
||||
# no user input required / expected
|
||||
if isinstance(hf_config, SpeculatorsConfig):
|
||||
# We create one since we dont create one
|
||||
# We create one since we don't create one
|
||||
self.speculative_config = {}
|
||||
self.speculative_config[
|
||||
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
|
||||
@ -1775,7 +1775,7 @@ class AsyncEngineArgs(EngineArgs):
|
||||
def add_cli_args(parser: FlexibleArgumentParser,
|
||||
async_args_only: bool = False) -> FlexibleArgumentParser:
|
||||
# Initialize plugin to update the parser, for example, The plugin may
|
||||
# adding a new kind of quantization method to --quantization argument or
|
||||
# add a new kind of quantization method to --quantization argument or
|
||||
# a new device to --device argument.
|
||||
load_general_plugins()
|
||||
if not async_args_only:
|
||||
|
||||
@ -539,7 +539,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
if request_id in self.output_queues:
|
||||
raise ValueError(f"Request {request_id} already exists")
|
||||
|
||||
# 1) Create output queue for this requests.
|
||||
# 1) Create output queue for this request.
|
||||
queue: asyncio.Queue[Union[RequestOutput,
|
||||
BaseException]] = asyncio.Queue()
|
||||
self.output_queues[request_id] = queue
|
||||
@ -651,7 +651,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Uses the same I/O as generate requests
|
||||
request = RPCLoadAdapterRequest(lora_request)
|
||||
|
||||
# Create output queue for this requests.
|
||||
# Create output queue for this request.
|
||||
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
|
||||
self.output_queues[request.request_id] = queue
|
||||
|
||||
|
||||
@ -1330,7 +1330,7 @@ def apply_mistral_chat_template(
|
||||
# mistral-common uses assert statements to stop processing of input
|
||||
# if input does not comply with the expected format.
|
||||
# We convert those assertion errors to ValueErrors so they can be
|
||||
# are properly caught in the preprocessing_input step
|
||||
# properly caught in the preprocessing_input step
|
||||
except (AssertionError, MistralCommonException) as e:
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
@ -67,15 +68,27 @@ class HarmonyContext(ConversationContext):
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
# TODO(woosuk): Implement the following fields.
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
# TODO(woosuk): Implement the following fields.
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
|
||||
def _update_num_prompt_tokens(self, output: RequestOutput):
|
||||
if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
|
||||
# NOTE: with built-in tools, there might be multiple rounds in
|
||||
# the conversation, with the full conversation being resent
|
||||
# as new prompt each time. Hence the sum.
|
||||
self.num_prompt_tokens += len(output.prompt_token_ids)
|
||||
|
||||
def _update_num_output_tokens(self, token_ids: Sequence[int]):
|
||||
self.num_output_tokens += len(token_ids)
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
self._update_num_prompt_tokens(output)
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self._update_num_output_tokens(output_token_ids)
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
for token_id in output_token_ids:
|
||||
self.parser.process(token_id)
|
||||
@ -158,6 +171,7 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.encoding = get_encoding()
|
||||
self.last_tok = None
|
||||
self.first_tok_of_message = True
|
||||
|
||||
@property
|
||||
def messages(self) -> list:
|
||||
@ -165,8 +179,18 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
if self.first_tok_of_message:
|
||||
self._update_num_prompt_tokens(output)
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
# beginning of a new message
|
||||
self.first_tok_of_message = output.finished
|
||||
tok = output.outputs[0].token_ids[0]
|
||||
self.parser.process(tok)
|
||||
self._update_num_output_tokens(output.outputs[0].token_ids)
|
||||
self.last_tok = tok
|
||||
else:
|
||||
# Handle the case of tool output in direct message format
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
@ -18,6 +19,7 @@ from .mistral_tool_parser import MistralToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
from .seed_oss_tool_parser import SeedOssToolParser
|
||||
from .step3_tool_parser import Step3ToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
@ -35,11 +37,13 @@ __all__ = [
|
||||
"PythonicToolParser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"DeepSeekV31ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
"Qwen3CoderToolParser",
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
]
|
||||
|
||||
367
vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
Normal file
367
vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
Normal file
@ -0,0 +1,367 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("deepseek_v31")
|
||||
class DeepSeekV31ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = (
|
||||
[]) # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
|
||||
|
||||
self.tool_call_start_token: str = "<|tool▁call▁begin|>"
|
||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)")
|
||||
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<function_name>.*)<|tool▁sep|>")
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"DeepSeek-V3 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = self.tool_call_regex.findall(
|
||||
model_output)
|
||||
|
||||
tool_calls = []
|
||||
for match in function_call_tuples:
|
||||
function_name, function_args = match
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args),
|
||||
))
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_calls_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_calls_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_name, tool_args = current_tool_call_matches.groups()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_name = current_tool_call_name_matches.groups()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
676
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
676
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
Normal file
@ -0,0 +1,676 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from qwen3coder xml parser, All rights reserved.
|
||||
# ruff: noqa: E501
|
||||
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("seed_oss")
|
||||
class SeedOssToolParser(ToolParser):
|
||||
TOOL_CALL_START = "<seed:tool_call>"
|
||||
TOOL_CALL_END = "</seed:tool_call>"
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# --- streaming state ---
|
||||
self._reset_streaming_state()
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
self.tool_call_start_token: str = self.TOOL_CALL_START
|
||||
self.tool_call_end_token: str = self.TOOL_CALL_END
|
||||
# Sentinel tokens for streaming mode
|
||||
self.tool_call_prefix: str = "<function="
|
||||
self.function_end_token: str = "</function>"
|
||||
self.parameter_prefix: str = "<parameter="
|
||||
self.parameter_end_token: str = "</parameter>"
|
||||
self.think_start_token: str = "<seed:think>"
|
||||
self.think_end_token: str = "</seed:think>"
|
||||
self.is_tool_call_started: bool = False
|
||||
self.is_thinking_end: bool = False
|
||||
self.failed_count: int = 0
|
||||
self._reset_streaming_state()
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Seed_Oss XML parser: tokenizer did not include "
|
||||
"<seed:tool_call> or its closing tag.")
|
||||
|
||||
tool_start_re = re.escape(self.tool_call_start_token)
|
||||
tool_end_re = re.escape(self.tool_call_end_token)
|
||||
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
|
||||
self.tool_call_regex = re.compile(
|
||||
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
|
||||
re.DOTALL)
|
||||
|
||||
self.tool_call_function_regex = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||
self.tool_call_parameter_regex = re.compile(
|
||||
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
|
||||
|
||||
logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
|
||||
self.__class__.__name__)
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
"""Generate a unique tool call ID."""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self):
|
||||
"""Reset all streaming state."""
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = -1
|
||||
self.current_function_name = None
|
||||
self.current_param_name = None
|
||||
self.current_param_value = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
|
||||
def _parse_xml_function_call(
|
||||
self, function_call_str: str,
|
||||
tools: Optional[list[ChatCompletionToolsParam]]
|
||||
) -> Optional[ToolCall]:
|
||||
|
||||
def get_arguments_config(func_name: str) -> dict:
|
||||
if tools is None:
|
||||
return {}
|
||||
for config in tools:
|
||||
if not hasattr(config, "type") or not (
|
||||
hasattr(config, "function")
|
||||
and hasattr(config.function, "name")):
|
||||
continue
|
||||
if (config.type == "function"
|
||||
and config.function.name == func_name):
|
||||
if not hasattr(config.function, "parameters"):
|
||||
return {}
|
||||
params = config.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
return params["properties"]
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
return {}
|
||||
logger.warning("Tool '%s' is not defined in the tools list.",
|
||||
func_name)
|
||||
return {}
|
||||
|
||||
def convert_param_value(param_value: str, param_name: str,
|
||||
param_config: dict, func_name: str) -> Any:
|
||||
# Handle null value for any type
|
||||
if param_value.lower() == "null":
|
||||
return None
|
||||
|
||||
if param_name not in param_config:
|
||||
if param_config != {}:
|
||||
logger.warning(
|
||||
"Parsed parameter '%s' is not defined in "
|
||||
"the tool parameters for tool '%s', "
|
||||
"directly returning the string value.", param_name,
|
||||
func_name)
|
||||
return param_value
|
||||
|
||||
if (isinstance(param_config[param_name], dict)
|
||||
and "type" in param_config[param_name]):
|
||||
param_type = str(
|
||||
param_config[param_name]["type"]).strip().lower()
|
||||
else:
|
||||
param_type = "string"
|
||||
if param_type in [
|
||||
"string", "str", "text", "varchar", "char", "enum"
|
||||
]:
|
||||
return param_value
|
||||
elif (param_type.startswith("int") or param_type.startswith("uint")
|
||||
or param_type.startswith("long")
|
||||
or param_type.startswith("short")
|
||||
or param_type.startswith("unsigned")):
|
||||
try:
|
||||
param_value = int(param_value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not an integer in tool "
|
||||
"'%s', degenerating to string.", param_value,
|
||||
param_name, func_name)
|
||||
return param_value
|
||||
elif param_type.startswith("num") or param_type.startswith(
|
||||
"float"):
|
||||
try:
|
||||
float_param_value = float(param_value)
|
||||
param_value = float_param_value if float_param_value - int(
|
||||
float_param_value) != 0 else int(
|
||||
float_param_value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a float in tool "
|
||||
"'%s', degenerating to string.", param_value,
|
||||
param_name, func_name)
|
||||
return param_value
|
||||
elif param_type in ["boolean", "bool", "binary"]:
|
||||
param_value = param_value.lower()
|
||||
if param_value not in ["true", "false"]:
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a boolean "
|
||||
"(`true` of `false`) in tool '%s', degenerating to false.",
|
||||
param_value, param_name, func_name)
|
||||
return param_value == "true"
|
||||
else:
|
||||
if param_type == "object" or param_type.startswith("dict"):
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
return param_value
|
||||
except (ValueError, TypeError, json.JSONDecodeError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' is not a valid JSON "
|
||||
"object in tool '%s', will try other methods to parse it.",
|
||||
param_value, param_name, func_name)
|
||||
try:
|
||||
param_value = ast.literal_eval(param_value)
|
||||
except (ValueError, SyntaxError):
|
||||
logger.warning(
|
||||
"Parsed value '%s' of parameter '%s' cannot be converted via "
|
||||
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
|
||||
param_value, param_name, func_name)
|
||||
return param_value
|
||||
|
||||
# Extract function name
|
||||
end_index = function_call_str.index(">")
|
||||
function_name = function_call_str[:end_index]
|
||||
param_config = get_arguments_config(function_name)
|
||||
parameters = function_call_str[end_index + 1:]
|
||||
param_dict = {}
|
||||
for match in self.tool_call_parameter_regex.findall(parameters):
|
||||
match_text = match[0] if match[0] else match[1]
|
||||
idx = match_text.index(">")
|
||||
param_name = match_text[:idx]
|
||||
param_value = str(match_text[idx + 1:])
|
||||
# Remove prefix and trailing \n
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = convert_param_value(
|
||||
param_value, param_name, param_config, function_name)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=json.dumps(param_dict,
|
||||
ensure_ascii=False)),
|
||||
)
|
||||
|
||||
def _get_function_calls(self, model_output: str) -> list[str]:
|
||||
# Find all tool calls
|
||||
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||
raw_tool_calls = [
|
||||
match[0] if match[0] else match[1] for match in matched_ranges
|
||||
]
|
||||
|
||||
# Back-off strategy if no tool_call tags found
|
||||
if len(raw_tool_calls) == 0:
|
||||
raw_tool_calls = [model_output]
|
||||
|
||||
raw_function_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
raw_function_calls.extend(
|
||||
self.tool_call_function_regex.findall(tool_call))
|
||||
|
||||
function_calls = [
|
||||
match[0] if match[0] else match[1] for match in raw_function_calls
|
||||
]
|
||||
return function_calls
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
# Quick check to avoid unnecessary processing
|
||||
if self.tool_call_prefix not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# Check if both think start and end tokens are present
|
||||
if (self.think_start_token in model_output
|
||||
and self.think_end_token in model_output):
|
||||
# Find the position of think end token
|
||||
think_end_index = model_output.find(self.think_end_token) + len(
|
||||
self.think_end_token)
|
||||
# Extract content after think end token
|
||||
result_content = model_output[think_end_index:]
|
||||
thinking_content = model_output[:think_end_index]
|
||||
|
||||
try:
|
||||
function_calls = self._get_function_calls(result_content)
|
||||
if len(function_calls) == 0:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
tool_calls = [
|
||||
self._parse_xml_function_call(function_call_str, request.tools)
|
||||
for function_call_str in function_calls
|
||||
]
|
||||
|
||||
# Populate prev_tool_call_arr for serving layer to set finish_reason
|
||||
self.prev_tool_call_arr.clear() # Clear previous calls
|
||||
for tool_call in tool_calls:
|
||||
if tool_call:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name":
|
||||
tool_call.function.name,
|
||||
"arguments":
|
||||
tool_call.function.arguments,
|
||||
})
|
||||
|
||||
# Extract content before tool calls
|
||||
tool_call_start_index = result_content.find(
|
||||
self.tool_call_start_token)
|
||||
tool_call_start_index = (
|
||||
tool_call_start_index if tool_call_start_index >= 0 else
|
||||
result_content.find(self.tool_call_prefix))
|
||||
content = thinking_content + result_content[:tool_call_start_index]
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=(len(tool_calls) > 0),
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# If no delta text, return None unless
|
||||
# it's an EOS token after tool calls
|
||||
if not delta_text:
|
||||
# Check if this is an EOS token after all tool calls are complete
|
||||
# We check for tool calls in the text even if is_tool_call_started
|
||||
# is False because it might have been reset after processing all tools
|
||||
if (delta_token_ids
|
||||
and self.tool_call_end_token_id not in delta_token_ids):
|
||||
# Count complete tool calls
|
||||
complete_calls = len(
|
||||
self.tool_call_complete_regex.findall(current_text))
|
||||
|
||||
# If we have completed tool calls and populated prev_tool_call_arr
|
||||
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
|
||||
# Check if all tool calls are closed
|
||||
open_calls = current_text.count(
|
||||
self.tool_call_start_token) - current_text.count(
|
||||
self.tool_call_end_token)
|
||||
if open_calls == 0:
|
||||
# Return empty delta message to allow finish_reason processing
|
||||
return DeltaMessage(content="")
|
||||
elif not self.is_tool_call_started and current_text:
|
||||
# This is a regular content response that's now complete
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
# Check if this is the first call (reset state if needed)
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
# Update accumulated text
|
||||
self.accumulated_text = current_text
|
||||
|
||||
# Check if we need to advance to next tool
|
||||
if self.json_closed and not self.in_function:
|
||||
# Check if this tool call has ended
|
||||
tool_ends = current_text.count(self.tool_call_end_token)
|
||||
if tool_ends > self.current_tool_index:
|
||||
# This tool has ended, advance to next
|
||||
self.current_tool_index += 1
|
||||
self.header_sent = False
|
||||
self.param_count = 0
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
|
||||
# Check if there are more tool calls
|
||||
if self.current_tool_index >= current_text.count(
|
||||
self.tool_call_start_token):
|
||||
# No more tool calls
|
||||
self.is_tool_call_started = False
|
||||
# Continue processing next tool
|
||||
return None
|
||||
|
||||
# Check if end thinking
|
||||
if (not self.is_thinking_end
|
||||
and (self.think_end_token_id in delta_token_ids
|
||||
or self.think_end_token in delta_text)):
|
||||
self.is_thinking_end = True
|
||||
|
||||
# If thinking hasn't ended yet, don't process any tool calls
|
||||
if not self.is_thinking_end:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Handle normal content before tool calls
|
||||
if not self.is_tool_call_started:
|
||||
# Check if tool call is starting
|
||||
if (self.tool_call_start_token_id in delta_token_ids
|
||||
or self.tool_call_start_token in delta_text):
|
||||
self.is_tool_call_started = True
|
||||
# Return any content before the tool call
|
||||
if self.tool_call_start_token in delta_text:
|
||||
content_before = delta_text[:delta_text.index(
|
||||
self.tool_call_start_token)]
|
||||
if content_before:
|
||||
return DeltaMessage(content=content_before)
|
||||
return None
|
||||
else:
|
||||
# Check if we're between tool calls - skip whitespace
|
||||
if (current_text.rstrip().endswith(self.tool_call_end_token)
|
||||
and delta_text.strip() == ""):
|
||||
# We just ended a tool call, skip whitespace
|
||||
return None
|
||||
# Normal content, no tool call
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Check if we're between tool calls (waiting for next one)
|
||||
# Count tool calls we've seen vs processed
|
||||
tool_starts_count = current_text.count(self.tool_call_start_token)
|
||||
if self.current_tool_index >= tool_starts_count:
|
||||
# We're past all tool calls, shouldn't be here
|
||||
return None
|
||||
|
||||
# We're in a tool call, find the current tool call portion
|
||||
# Need to find the correct tool call based on current_tool_index
|
||||
# Only process tool calls after think_end_token
|
||||
think_end_index = current_text.find(self.think_end_token) + len(
|
||||
self.think_end_token
|
||||
) if self.think_end_token in current_text else 0
|
||||
tool_starts: list[int] = []
|
||||
idx = think_end_index
|
||||
while True:
|
||||
idx = current_text.find(self.tool_call_start_token, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
tool_starts.append(idx)
|
||||
idx += len(self.tool_call_start_token)
|
||||
|
||||
if self.current_tool_index >= len(tool_starts):
|
||||
# No more tool calls to process yet
|
||||
return None
|
||||
|
||||
tool_start_idx = tool_starts[self.current_tool_index]
|
||||
# Find where this tool call ends (or current position if not ended yet)
|
||||
tool_end_idx = current_text.find(self.tool_call_end_token,
|
||||
tool_start_idx)
|
||||
if tool_end_idx == -1:
|
||||
tool_text = current_text[tool_start_idx:]
|
||||
else:
|
||||
tool_text = current_text[tool_start_idx:tool_end_idx +
|
||||
len(self.tool_call_end_token)]
|
||||
|
||||
# Looking for function header
|
||||
if not self.header_sent:
|
||||
if self.tool_call_prefix in tool_text:
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_end = tool_text.find(">", func_start)
|
||||
|
||||
if func_end != -1:
|
||||
# Found complete function name
|
||||
self.current_function_name = tool_text[func_start:func_end]
|
||||
self.current_tool_id = self._generate_tool_call_id(
|
||||
) # type: ignore
|
||||
self.header_sent = True
|
||||
self.in_function = True
|
||||
|
||||
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
|
||||
# This ensures finish_reason="tool_calls" even if parsing isn't complete
|
||||
already_added = any(
|
||||
tool.get("name") == self.current_function_name
|
||||
for tool in self.prev_tool_call_arr)
|
||||
if not already_added:
|
||||
self.prev_tool_call_arr.append({
|
||||
"name": self.current_function_name,
|
||||
"arguments":
|
||||
"{}", # Placeholder, will be updated later
|
||||
})
|
||||
|
||||
# Send header with function info
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
id=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=self.current_function_name, arguments=""),
|
||||
type="function",
|
||||
)
|
||||
])
|
||||
return None
|
||||
|
||||
# We've sent header, now handle function body
|
||||
if self.in_function:
|
||||
# Send opening brace if not sent yet
|
||||
if (not self.json_started
|
||||
and self.parameter_prefix not in delta_text):
|
||||
self.json_started = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="{"),
|
||||
)
|
||||
])
|
||||
|
||||
# Make sure json_started is set if we're processing parameters
|
||||
if not self.json_started:
|
||||
self.json_started = True
|
||||
|
||||
# Check for function end in accumulated text
|
||||
if not self.json_closed and self.function_end_token in tool_text:
|
||||
# Close JSON
|
||||
self.json_closed = True
|
||||
|
||||
# Extract the complete tool call to update prev_tool_call_arr with final arguments
|
||||
# Find the function content
|
||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||
self.tool_call_prefix)
|
||||
func_content_end = tool_text.find(self.function_end_token,
|
||||
func_start)
|
||||
if func_content_end != -1:
|
||||
func_content = tool_text[func_start:func_content_end]
|
||||
# Parse to get the complete arguments
|
||||
try:
|
||||
parsed_tool = self._parse_xml_function_call(
|
||||
func_content, request.tools if request else None)
|
||||
if parsed_tool:
|
||||
# Update existing entry in prev_tool_call_arr with complete arguments
|
||||
for i, tool in enumerate(self.prev_tool_call_arr):
|
||||
if tool.get(
|
||||
"name") == parsed_tool.function.name:
|
||||
self.prev_tool_call_arr[i]["arguments"] = (
|
||||
parsed_tool.function.arguments)
|
||||
break
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse tool arguments during streaming.",
|
||||
exc_info=True)
|
||||
|
||||
result = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(arguments="}"),
|
||||
)
|
||||
])
|
||||
|
||||
# Reset state for next tool
|
||||
self.in_function = False
|
||||
self.json_closed = True
|
||||
|
||||
return result
|
||||
|
||||
# Look for parameters
|
||||
# Count how many complete parameters we have processed
|
||||
complete_params = tool_text.count(self.parameter_end_token)
|
||||
|
||||
# Check if we should start a new parameter
|
||||
if not self.in_param and self.param_count < complete_params:
|
||||
# Find the unprocessed parameter
|
||||
# Count parameter starts
|
||||
param_starts = []
|
||||
idx = 0
|
||||
while True:
|
||||
idx = tool_text.find(self.parameter_prefix, idx)
|
||||
if idx == -1:
|
||||
break
|
||||
param_starts.append(idx)
|
||||
idx += len(self.parameter_prefix)
|
||||
|
||||
if len(param_starts) > self.param_count:
|
||||
# Process the next parameter
|
||||
param_idx = param_starts[self.param_count]
|
||||
param_start = param_idx + len(self.parameter_prefix)
|
||||
remaining = tool_text[param_start:]
|
||||
|
||||
if ">" in remaining:
|
||||
# We have the complete parameter name
|
||||
name_end = remaining.find(">")
|
||||
self.current_param_name = remaining[:name_end]
|
||||
|
||||
# Find the parameter value
|
||||
value_start = param_start + name_end + 1
|
||||
value_text = tool_text[value_start:]
|
||||
if value_text.startswith("\n"):
|
||||
value_text = value_text[1:]
|
||||
|
||||
# Find where this parameter ends
|
||||
param_end_idx = value_text.find(
|
||||
self.parameter_end_token)
|
||||
if param_end_idx != -1:
|
||||
# Complete parameter found
|
||||
param_value = value_text[:param_end_idx]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
# Build complete JSON fragment for this parameter
|
||||
if self.param_count == 0:
|
||||
json_fragment = (
|
||||
'"' + self.current_param_name + '": "' +
|
||||
json.dumps(param_value)[1:-1] + '"')
|
||||
else:
|
||||
json_fragment = (
|
||||
', "' + self.current_param_name + '": "' +
|
||||
json.dumps(param_value)[1:-1] + '"')
|
||||
|
||||
self.param_count += 1
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=json_fragment),
|
||||
)
|
||||
])
|
||||
|
||||
# Continue parameter value
|
||||
if self.in_param:
|
||||
if self.parameter_end_token in delta_text:
|
||||
# End of parameter
|
||||
end_idx = delta_text.find(self.parameter_end_token)
|
||||
value_chunk = delta_text[:end_idx]
|
||||
|
||||
# Skip past > if at start
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
# Calculate incremental JSON
|
||||
full_value = self.current_param_value + value_chunk
|
||||
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
|
||||
if self.current_param_value else "")
|
||||
full_escaped = json.dumps(full_value)[1:-1]
|
||||
delta_escaped = full_escaped[len(prev_escaped):]
|
||||
|
||||
self.in_param = False
|
||||
self.current_param_value = ""
|
||||
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_escaped + '"'),
|
||||
)
|
||||
])
|
||||
else:
|
||||
# Continue accumulating value
|
||||
value_chunk = delta_text
|
||||
|
||||
# Handle first chunk after param name
|
||||
if not self.current_param_value and ">" in value_chunk:
|
||||
gt_idx = value_chunk.find(">")
|
||||
value_chunk = value_chunk[gt_idx + 1:]
|
||||
|
||||
if not self.current_param_value and value_chunk.startswith(
|
||||
"\n"):
|
||||
value_chunk = value_chunk[1:]
|
||||
|
||||
if value_chunk:
|
||||
# Stream the escaped delta
|
||||
prev_escaped = (json.dumps(
|
||||
self.current_param_value)[1:-1]
|
||||
if self.current_param_value else "")
|
||||
self.current_param_value += value_chunk
|
||||
full_escaped = json.dumps(
|
||||
self.current_param_value)[1:-1]
|
||||
delta_escaped = full_escaped[len(prev_escaped):]
|
||||
|
||||
if delta_escaped:
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_escaped),
|
||||
)
|
||||
])
|
||||
|
||||
return None
|
||||
11
vllm/envs.py
11
vllm/envs.py
@ -158,8 +158,10 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
|
||||
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||
|
||||
|
||||
@ -1105,6 +1107,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_TRTLLM_ATTENTION":
|
||||
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
|
||||
|
||||
# If set, it means we pre-downloaded cubin files and flashinfer will
|
||||
# read the cubin files directly.
|
||||
"VLLM_HAS_FLASHINFER_CUBIN":
|
||||
lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
|
||||
|
||||
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
|
||||
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
|
||||
# vllm cutlass GEMM, marlin GEMM.
|
||||
@ -1150,6 +1157,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
|
||||
|
||||
# Whether to use pytorch symmetric memory for allreduce
|
||||
"VLLM_ALLREDUCE_USE_SYMM_MEM":
|
||||
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
|
||||
|
||||
# Allows vllm to find tuned config under customized folder
|
||||
"VLLM_TUNED_CONFIG_FOLDER":
|
||||
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
{
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8192": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16384": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,114 @@
|
||||
{
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"8192": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"16384": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@ -52,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"HQQMarlinMethod",
|
||||
"QuarkLinearMethod",
|
||||
"ModelOptNvFp4LinearMethod",
|
||||
"PetitNvFp4LinearMethod",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,7 @@ QuantizationMethods = Literal[
|
||||
"rtn",
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .petit import PetitNvFp4Config
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .rtn import RTNConfig
|
||||
from .torchao import TorchAOConfig
|
||||
@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"rtn": RTNConfig,
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
@ -13,7 +13,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
@ -28,8 +29,10 @@ logger = init_logger(__name__)
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
def __init__(self,
|
||||
unquantized_modules: Optional[list[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.unquantized_modules = unquantized_modules or []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
|
||||
return any(module_name in prefix for module_name in unquantized_modules)
|
||||
|
||||
|
||||
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
WeightType.Q4_0,
|
||||
|
||||
306
vllm/model_executor/layers/quantization/petit.py
Normal file
306
vllm/model_executor/layers/quantization/petit.py
Normal file
@ -0,0 +1,306 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.petit_utils import (
|
||||
apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit,
|
||||
verify_petit_nvfp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Configuration class to support the NVFP4 quantized model
|
||||
# generated by the ModelOpt quantization tool
|
||||
class PetitNvFp4Config(QuantizationConfig):
|
||||
"""Config class for Petit FP4."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_nvfp4_serialized: bool = False,
|
||||
kv_cache_quant_algo: Optional[str] = None,
|
||||
group_size: Optional[int] = None,
|
||||
exclude_modules: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
self._check_hardware_support()
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
if is_checkpoint_nvfp4_serialized:
|
||||
logger.warning("Detected nvfp4 checkpoint. Please note that the "
|
||||
"format is experimental and subject to change.")
|
||||
self.group_size = group_size
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
def _check_hardware_support(self) -> None:
|
||||
"""
|
||||
Verifies that the current hardware is supported by the Petit backend.
|
||||
This backend is specifically designed for AMD GPUs and is not
|
||||
supported on the CUDA platform.
|
||||
"""
|
||||
# This check ensures the code is NOT running on an NVIDIA GPU.
|
||||
if current_platform.is_cuda():
|
||||
raise ValueError(
|
||||
"The 'petit' quantization backend is designed for AMD GPUs "
|
||||
"and is not supported on the CUDA platform. For NVIDIA GPUs, "
|
||||
"please use a different quantization method such as FP8, AWQ, "
|
||||
"or GPTQ.")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "petit_nvfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Petit supports the gfx90a and gfx942 GPUs
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
|
||||
qc = cls.get_from_keys(config, ["quantization"])
|
||||
|
||||
quant_method_raw = qc.get("quant_algo")
|
||||
if not isinstance(quant_method_raw, str) or not quant_method_raw:
|
||||
raise ValueError(
|
||||
"Missing or invalid 'quant_algo' in quantization config.")
|
||||
quant_method = quant_method_raw.upper()
|
||||
|
||||
group_size_raw = qc.get("group_size")
|
||||
if not isinstance(group_size_raw, int):
|
||||
raise ValueError(
|
||||
"Missing or invalid 'group_size' (int) in hf_quant_config.json."
|
||||
)
|
||||
group_size = group_size_raw
|
||||
|
||||
verify_petit_nvfp4_supported(quant_method, group_size)
|
||||
|
||||
kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
|
||||
if not isinstance(kv_cache_quant_algo_raw, str):
|
||||
raise ValueError(
|
||||
"'kv_cache_quant_algo' must be a string if provided.")
|
||||
kv_cache_quant_algo = kv_cache_quant_algo_raw
|
||||
|
||||
exclude_raw = qc.get("exclude_modules", [])
|
||||
if exclude_raw is None:
|
||||
exclude_modules: list[str] = []
|
||||
elif isinstance(exclude_raw, list) and all(
|
||||
isinstance(x, str) for x in exclude_raw):
|
||||
exclude_modules = exclude_raw
|
||||
else:
|
||||
raise ValueError(
|
||||
"'exclude_modules' must be a list[str] (or omitted).")
|
||||
|
||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||
|
||||
return cls(
|
||||
is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
|
||||
kv_cache_quant_algo=kv_cache_quant_algo,
|
||||
group_size=group_size,
|
||||
exclude_modules=exclude_modules,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
if not current_platform.is_rocm():
|
||||
return None
|
||||
|
||||
qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
|
||||
return cls.get_name() # "petit_nvfp4"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
|
||||
qc = quant_config.get("quantization", quant_config)
|
||||
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
|
||||
return algo == "NVFP4"
|
||||
|
||||
def is_layer_excluded(self, prefix: str,
|
||||
exclude_modules: list[str]) -> bool:
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
exclude = self.require_exclude_modules()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
|
||||
prefix, exclude):
|
||||
return UnquantizedLinearMethod()
|
||||
return PetitNvFp4LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
return PetitFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def require_group_size(self) -> int:
|
||||
if self.group_size is None:
|
||||
logger.warning("group_size not set; defaulting to 16 for NVFP4.")
|
||||
return 16
|
||||
return self.group_size
|
||||
|
||||
def require_kv_cache_quant_algo(self) -> str:
|
||||
return self.kv_cache_quant_algo or "auto"
|
||||
|
||||
def require_exclude_modules(self) -> list[str]:
|
||||
return list(self.exclude_modules or [])
|
||||
|
||||
|
||||
class PetitFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PetitNvFp4Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    |Tensor Name | datatype | shape |
    |----------------------------------------------------|
    |input_scale | torch.float32 | scalar |
    |weight | NVFP4(SE2M1) | [1, X, y/2] |
    |weight_scale | FP8-E4M3 | [X, Y] |
    |weight_scale_2 | torch.float32 | scalar |

    The weights are quantized per block of 16 elements.
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate and register all parameters for one partitioned linear
        # layer; actual values are filled in by the weight loader.
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        # NVFP4 scales are per 16-element group, so K must divide evenly.
        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 data is packed in one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One per-tensor input scale per logical output shard.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        # Second-level (global) weight scale, one per logical shard.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # First-level group scales: one FP8 scale per `group_size` inputs.
        group_size = self.quant_config.require_group_size()
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Collapse per-shard scales to a single per-tensor scalar (max over
        # shards), then precompute alpha = input_scale * weight_scale_2
        # used by the kernel.
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Repack weight/scales into Petit's kernel layout.
        prepare_nvfp4_layer_for_petit(layer)
        # input_scale is folded into alpha; free the original tensor.
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Forward pass: delegate the NVFP4 GEMM to the Petit kernel wrapper.
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
|
||||
122
vllm/model_executor/layers/quantization/utils/petit_utils.py
Normal file
122
vllm/model_executor/layers/quantization/utils/petit_utils.py
Normal file
@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
# TYPE_CHECKING is used for static type analysis to prevent circular imports.
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
# 1. Create a global variable as a placeholder for the module
|
||||
_petit_kernel: Optional["ModuleType"] = None
|
||||
|
||||
_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with "
|
||||
"`pip install petit-kernel`.")
|
||||
|
||||
|
||||
def _import_petit_kernel() -> "ModuleType":
    """
    A helper function to handle the lazy import.
    The first time this function is called, it will import the petit_kernel
    library and store it in the global _petit_kernel variable.
    Subsequent calls will return the already-loaded module directly.

    Raises:
        ImportError: if the `petit-kernel` package is not installed.
    """
    global _petit_kernel
    # Fast path: module already imported on a previous call.
    if _petit_kernel is not None:
        return _petit_kernel

    try:
        import petit_kernel
        _petit_kernel = petit_kernel
        return _petit_kernel
    except ImportError:
        # The 'from None' syntax prevents chaining the original ImportError,
        # making the traceback cleaner.
        raise ImportError(_PETIT_INSTALL_MSG) from None
|
||||
|
||||
|
||||
# The _require_petit function can now be a simple alias for consistency.
|
||||
_require_petit = _import_petit_kernel
|
||||
|
||||
|
||||
def _check_petit_nvfp4_supported(
|
||||
quant_method: str,
|
||||
group_size: Optional[int]) -> tuple[bool, Optional[str]]:
|
||||
if quant_method != "NVFP4":
|
||||
return (
|
||||
False,
|
||||
("Petit currently only supports: NVFP4 quantizations in sglang. "
|
||||
"Please check the `hf_quant_config.json` file for your model's "
|
||||
"quant configuration."),
|
||||
)
|
||||
if group_size is not None and group_size != 16:
|
||||
return (
|
||||
False,
|
||||
"Petit currently only supports: group_size=16 quantizations.",
|
||||
)
|
||||
return (True, None)
|
||||
|
||||
|
||||
def verify_petit_nvfp4_supported(quant_method: str,
                                 group_size: Optional[int]) -> None:
    """Raise ValueError when the config is not supported by Petit."""
    supported, error_msg = _check_petit_nvfp4_supported(
        quant_method, group_size)
    if supported:
        return
    assert error_msg is not None
    raise ValueError(error_msg)
|
||||
|
||||
|
||||
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
    """Repack a loaded NVFP4 linear layer into Petit's kernel layout.

    Mutates `layer.weight` and `layer.weight_scale` in place; called once
    after checkpoint loading.
    """
    # 2. Call _import_petit_kernel() to trigger (or get) the import.
    petit_kernel = _import_petit_kernel()

    # Repack weights to petit format
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    # Petit expects the packed fp4 payload viewed as int32 words.
    qweight = layer.weight.view(torch.int32).contiguous()

    # 3. Call functions through the imported module variable.
    petit_qweight = petit_kernel.repack_nvfp4(qweight,
                                              size_n=part_size_n,
                                              size_k=part_size_k)
    layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)

    # Permute scales
    weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale,
                                                     size_k=part_size_k,
                                                     size_n=part_size_n)
    layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
|
||||
def apply_petit_nvfp4_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the Petit NVFP4 GEMM: out = input @ weight^T (+ bias).

    `input` may have any leading batch dims; it is flattened to 2-D for the
    kernel and reshaped back afterwards.
    """
    # Trigger (or get) the import here as well.
    petit_kernel = _import_petit_kernel()

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    # TODO: Use auto-tuning to find the performant solution_id
    # Call the function via the module variable.
    output = petit_kernel.mul_nvfp4_a16(
        a=reshaped_x,
        b=weight,
        s=weight_scale,
        global_scale=weight_scale_2,
        size_m=reshaped_x.size(0),
        size_n=size_n,
        size_k=size_k,
        solution_id=-1,  # -1 lets the kernel pick a default solution
    )
    if bias is not None:
        output.add_(bias)  # In-place add
    return output.reshape(out_shape)
|
||||
@ -2,16 +2,21 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ScaleDesc:
    """Immutable description of a single quantization scaling factor.

    Attributes:
        dtype: data type of the scale.
        static: True for a static scale, False for a dynamic one.
        group_shape: group shape over which the scale applies.
    """
    dtype: torch.dtype
    static: bool
    group_shape: GroupShape

    def __str__(self):
        # Render the well-known group shapes by name; anything else falls
        # back to the tuple repr.
        if self.group_shape == GroupShape.PER_TENSOR:
            shape_repr = 'per_tensor'
        elif self.group_shape == GroupShape.PER_TOKEN:
            shape_repr = 'per_token'
        else:
            shape_repr = str(self.group_shape)
        mode = 'static' if self.static else 'dynamic'
        return f"{fx.graph.dtype_abbrs[self.dtype]},{mode},{shape_repr}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class QuantKey:
    """Immutable identifier for a quantization scheme.

    Attributes:
        dtype: quantized data type.
        scale: first-level scale descriptor.
        scale2: optional second-level (global) scale descriptor.
        symmetric: True for symmetric quantization, False for asymmetric.
    """
    dtype: torch.dtype
    scale: ScaleDesc
    scale2: Optional[ScaleDesc] = None
    symmetric: bool = True

    def __str__(self):
        # Build the display piecewise; the scale2 segment is optional.
        pieces = [
            f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},",
            f"scale({self.scale}),",
        ]
        if self.scale2:
            pieces.append(f"scale2({self.scale2}),")
        pieces.append(f"{'a' if not self.symmetric else ''}symmetric)")
        return "".join(pieces)
|
||||
|
||||
|
||||
# Canonical scale/quant-key singletons shared across the codebase.

# Static per-tensor fp32 scale (scale known ahead of time).
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)

# Dynamic per-tensor fp32 scale (computed at runtime).
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)

# Dynamic per-token fp32 scale.
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)

# NVFP4: fp8 group scales over 1x16 blocks plus a static per-tensor
# second-level scale.
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE,
                       scale=kNvfp4GroupScale,
                       scale2=kStaticTensorScale)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=bias)
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(
        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
        input_2d: torch.Tensor) -> torch.Tensor:
    """Shape-only fake impl for the custom op (used during tracing)."""
    # Output keeps qinput's leading dims; last dim comes from the weight.
    out_shape = qinput.shape[:-1] + (weight.shape[1], )
    return qinput.new_empty(out_shape, dtype=out_dtype)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: list) -> torch.Tensor:
    """ROCm per-tensor w8a8 scaled matmul via the registered custom op.

    Dispatches through torch.ops so torch.compile can trace it, then trims
    any kernel padding rows and restores the caller's output shape.
    """
    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
    # The impl may produce extra rows (e.g. padded M); keep only the rows
    # matching the real 2-D input before reshaping.
    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
|
||||
@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
|
||||
get_gguf_extra_tensor_names, get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
local_model_path, gguf_weights_map):
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
weight_type_map = get_gguf_weight_type_map(model_config.model,
|
||||
gguf_weights_map)
|
||||
|
||||
# filter out unquantized modules to skip
|
||||
unquant_names = [
|
||||
name.removesuffix(".weight")
|
||||
for name, weight_type in weight_type_map.items()
|
||||
if weight_type == "F32" and name.endswith(".weight")
|
||||
]
|
||||
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
|
||||
@ -31,9 +31,7 @@ from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except (ImportError, OSError):
|
||||
# see https://github.com/run-ai/runai-model-streamer/issues/26
|
||||
# OSError will be raised on arm64 platform
|
||||
except ImportError:
|
||||
runai_model_streamer = PlaceholderModule(
|
||||
"runai_model_streamer") # type: ignore[assignment]
|
||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||
@ -565,6 +563,18 @@ def get_gguf_extra_tensor_names(
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def get_gguf_weight_type_map(
        gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
    """
    Return GGUF mapped weight's name and its quant type
    """
    weight_type_map: dict[str, str] = {}
    # Walk every tensor in the GGUF file; keep only the ones we can map to
    # an HF parameter name.
    for tensor in gguf.GGUFReader(gguf_file).tensors:
        hf_name = gguf_to_hf_name_map.get(tensor.name)
        if hf_name is not None:
            weight_type_map[hf_name] = tensor.tensor_type.name
    return weight_type_map
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
|
||||
@ -8,7 +8,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.attn = Attention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.ENCODER_ONLY)
|
||||
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(**rotary_kwargs)
|
||||
|
||||
self.attn = Attention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.ENCODER_ONLY)
|
||||
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
self.out_proj = RowParallelLinear(input_size=hidden_size,
|
||||
output_size=hidden_size,
|
||||
|
||||
@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module):
|
||||
|
||||
encoder_hidden_states = None
|
||||
|
||||
if inputs_embeds is not None or encoder_input_ids.numel() > 0:
|
||||
if ((inputs_embeds is not None and inputs_embeds.numel() > 0)
|
||||
or encoder_input_ids.numel() > 0):
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
||||
self.lm_head = BartParallelLMHead(self.vocab_size,
|
||||
config.d_model,
|
||||
embed_scale=embed_scale)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.tie_weights(self.model.shared)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.vocab_size,
|
||||
config.vocab_size)
|
||||
@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
|
||||
else:
|
||||
if "final_logits_bias" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "embed_tokens" in name:
|
||||
if self.config.tie_word_embeddings and ("embed_tokens" in name
|
||||
or "lm_head" in name):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from ..layers.activation import SiluAndMul
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision
|
||||
from .qwen2_vl import (_create_qwen2vl_field_factory,
|
||||
apply_rotary_pos_emb_vision)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
@ -1153,7 +1154,9 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _qwen2vl_field_config(hf_inputs)
|
||||
return _create_qwen2vl_field_factory(
|
||||
self.info.get_hf_config().vision_config.spatial_merge_size)(
|
||||
hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
|
||||
@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import (
|
||||
Idefics2Config, Idefics2VisionConfig)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Idefics2VisionEmbeddings(nn.Module):
|
||||
@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module):
|
||||
config: Idefics2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module):
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert self.num_heads % tp_size == 0
|
||||
self.num_heads_per_partition = self.num_heads // tp_size
|
||||
|
||||
if use_data_parallel:
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
3 * self.q_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module):
|
||||
config: Idefics2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
cls_fc1 = (ReplicatedLinear
|
||||
if use_data_parallel else ColumnParallelLinear)
|
||||
self.fc1 = cls_fc1(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
cls_fc2 = (ReplicatedLinear
|
||||
if use_data_parallel else RowParallelLinear)
|
||||
self.fc2 = cls_fc2(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
config: Idefics2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Idefics2VisionAttention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.self_attn = Idefics2VisionAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = Idefics2VisionMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module):
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module):
|
||||
self.layers = nn.ModuleList([
|
||||
Idefics2EncoderLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: bool = True,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
self.config = config
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.embeddings = Idefics2VisionEmbeddings(config)
|
||||
self.encoder = Idefics2Encoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder")
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if len(self.encoder.layers) > config.num_hidden_layers:
|
||||
@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
tgt_sizes=tgt_sizes,
|
||||
)
|
||||
encoder_outputs = self.encoder(hidden_states)
|
||||
if self.use_data_parallel:
|
||||
encoder_outputs = run_dp_sharded_vision_model(
|
||||
hidden_states, self.encoder)
|
||||
else:
|
||||
encoder_outputs = self.encoder(hidden_states)
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
|
||||
def _consolidate_qkv_weights(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
qkv_idx_mappings = {
|
||||
".self_attn.q_proj": 0,
|
||||
".self_attn.k_proj": 1,
|
||||
".self_attn.v_proj": 2,
|
||||
}
|
||||
qkv_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
for weight_name, idx in qkv_idx_mappings.items():
|
||||
if weight_name not in name:
|
||||
continue
|
||||
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
|
||||
if new_name not in qkv_weights:
|
||||
qkv_weights[new_name] = [None] * 3
|
||||
qkv_weights[new_name][idx] = loaded_weight
|
||||
break
|
||||
else:
|
||||
yield name, loaded_weight
|
||||
for key, weight in qkv_weights.items():
|
||||
qkv_weight = torch.cat(weight, dim=0)
|
||||
yield key, qkv_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
if self.use_data_parallel:
|
||||
weights = self._consolidate_qkv_weights(weights)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# skip pooling header
|
||||
if name.startswith("head."):
|
||||
@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
if weight_name not in name or self.use_data_parallel:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
|
||||
@ -31,6 +31,7 @@ from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
|
||||
if is_sliding:
|
||||
sliding_window = config.sliding_window
|
||||
|
||||
self.attn = Attention(
|
||||
attn_cls = (EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||
|
||||
self.attn = attn_cls(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
|
||||
@ -24,7 +24,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
|
||||
MiniCPMVDummyInputsBuilder,
|
||||
@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
|
||||
class MiniCPMOAudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
audio_features: Union[torch.Tensor, list[torch.Tensor]]
|
||||
class MiniCPMOAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bns: Batch size * number of audios * number of slices
|
||||
- bn: Batch size * number of audios
|
||||
- c: Number of channels
|
||||
- l: Length
|
||||
- s: Number of slices
|
||||
"""
|
||||
type: Literal["audio_features"] = "audio_features"
|
||||
|
||||
audio_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bns", "c", "l", dynamic_dims={"l"}),
|
||||
]
|
||||
"""
|
||||
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
|
||||
Slice here means chunk. Audio that is too long will be split into slices,
|
||||
which is the same as image.
|
||||
Padding is used therefore `audio_features` is `torch.Tensor`.
|
||||
which is the same as image. Padding is used therefore `audio_features` is
|
||||
`torch.Tensor`.
|
||||
"""
|
||||
|
||||
audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
|
||||
audio_feature_lens: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "s"),
|
||||
]
|
||||
"""
|
||||
Shape: `(batch_size * num_audios, num_slices)`
|
||||
|
||||
This should be feature length of each audio slice,
|
||||
which equals to `audio_features.shape[-1]`
|
||||
"""
|
||||
|
||||
|
||||
class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
class MiniCPMOAudioEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `(batch_size * num_audios, num_slices, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
instead of a batched tensor.
|
||||
Dimensions:
|
||||
- bn: Batch size * number of audios
|
||||
- s: Number of slices
|
||||
- h: Hidden size (must match language model backbone)
|
||||
|
||||
Length of each slice may vary, so pass it as a list.
|
||||
"""
|
||||
type: Literal["audio_embeds"] = "audio_embeds"
|
||||
|
||||
audio_embeds: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "s", "h", dynamic_dims={"s"}),
|
||||
]
|
||||
|
||||
|
||||
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
|
||||
|
||||
@ -778,6 +778,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# and config class
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(vllm_config=vllm_config,
|
||||
@ -1325,9 +1326,11 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
quant_config = self._maybe_ignore_quant_config(quant_config)
|
||||
model = Idefics2VisionTransformer(config.vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
model = Idefics2VisionTransformer(
|
||||
config.vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
use_data_parallel=self.use_data_parallel)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
return model
|
||||
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import ModernBertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
dim=self.head_dim,
|
||||
base=rope_theta)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
prefix=f"{layer_id}.attn",
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
per_layer_sliding_window=sliding_window)
|
||||
self.attn = EncoderOnlyAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
prefix=f"{layer_id}.attn",
|
||||
per_layer_sliding_window=sliding_window)
|
||||
self.Wo = RowParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.attention_bias)
|
||||
|
||||
@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
IMAGE_TOKEN = "<image>"
|
||||
VIDEO_TOKEN = "<video>"
|
||||
@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
|
||||
visual_vocab_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.vit",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
# reserved tokens for INDICATOR_IDS
|
||||
head_dim = visual_vocab_size - len(INDICATOR_IDS)
|
||||
@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
model_type = config.model_type
|
||||
if model_type == "siglip2_navit":
|
||||
return Siglip2NavitModel(config=config, )
|
||||
return Siglip2NavitModel(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
use_data_parallel=use_data_parallel)
|
||||
raise ValueError(
|
||||
f"Unsupported visual tokenizer model_type: {model_type}")
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(self.head.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
def device(self) -> torch.device:
|
||||
return next(self.head.parameters()).device
|
||||
|
||||
def tokenize(self, logits):
|
||||
def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
tokens = torch.softmax(logits, dim=-1,
|
||||
dtype=torch.float32).to(logits.dtype)
|
||||
return tokens
|
||||
|
||||
def encode(self, pixel_values, grid_thws):
|
||||
features = self.vit(pixel_values,
|
||||
grid_thws,
|
||||
output_hidden_states=True,
|
||||
return_dict=True)
|
||||
def encode(self, pixel_values: torch.Tensor,
|
||||
grid_thws: torch.Tensor) -> torch.Tensor:
|
||||
features = self.vit(pixel_values, grid_thws)
|
||||
# refer to qwen2.5-vl patchmerger
|
||||
seq_len, _ = features.shape
|
||||
features = features.reshape(seq_len // (self.config.hidden_stride**2),
|
||||
@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
|
||||
|
||||
return features
|
||||
|
||||
def forward(self, pixel_values, grid_thws) -> torch.Tensor:
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
grid_thws: torch.Tensor) -> torch.Tensor:
|
||||
features = self.encode(pixel_values, grid_thws)
|
||||
logits = self.head(features)
|
||||
tokens = self.tokenize(logits)
|
||||
@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
|
||||
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
|
||||
info=Ovis2_5ProcessingInfo,
|
||||
dummy_inputs=Ovis2_5DummyInputsBuilder)
|
||||
class Ovis2_5(nn.Module, SupportsMultiModal):
|
||||
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
|
||||
text_model_type = self.config.get_text_config().model_type
|
||||
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
|
||||
|
||||
# TODO(Isotr0py): PP support
|
||||
# self.make_empty_intermediate_tensors = (
|
||||
# self.language_model.make_empty_intermediate_tensors)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.get_language_model().make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_visual_input(
|
||||
self, is_video,
|
||||
@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
return self.llm
|
||||
|
||||
@ -32,6 +32,7 @@ from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@ -51,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import is_interleaved
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(
|
||||
attn_cls = (EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||
self.attn = attn_cls(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
@ -439,7 +442,7 @@ class Qwen2Model(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -485,6 +488,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -25,7 +25,7 @@
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
|
||||
torch.empty((0, )))
|
||||
def create_qwen2_5_omni_thinker_field_factory(
|
||||
spatial_merge_size: int
|
||||
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
||||
MultiModalFieldConfig]]:
|
||||
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_grid_sizes = image_grid_thw.prod(-1)
|
||||
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
|
||||
torch.Tensor]):
|
||||
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
|
||||
torch.empty((0, )))
|
||||
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_pixel_grid_sizes = image_grid_thw.prod(-1)
|
||||
image_embed_grid_sizes = (image_pixel_grid_sizes //
|
||||
spatial_merge_size // spatial_merge_size)
|
||||
|
||||
num_videos = len(video_grid_sizes)
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
|
||||
spatial_merge_size)
|
||||
|
||||
return dict(
|
||||
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_feature_lengths, dim=1),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
|
||||
use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
num_videos = len(video_grid_sizes)
|
||||
|
||||
return dict(
|
||||
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_feature_lengths, dim=1),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_pixel_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_embed_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_embed_grid_sizes),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
|
||||
use_audio_in_video=MultiModalFieldConfig.shared(
|
||||
"video", num_videos),
|
||||
)
|
||||
|
||||
return _qwen2_5_omni_thinker_field_config
|
||||
|
||||
|
||||
class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
|
||||
|
||||
def __init__(self, spatial_merge_size: int, *args, **kwargs):
|
||||
self._spatial_merge_size = spatial_merge_size
|
||||
super().__init__(self._spatial_merge_size, *args, **kwargs)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
|
||||
@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
|
||||
required_fields={
|
||||
"input_audio_features", "audio_feature_lengths"
|
||||
},
|
||||
fields_factory=_qwen2_5_omni_thinker_field_config,
|
||||
fields_factory=create_qwen2_5_omni_thinker_field_factory(
|
||||
self._spatial_merge_size),
|
||||
)
|
||||
|
||||
return super()._parse_audio_data(data)
|
||||
@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return Qwen2_5OmniThinkerMultiModalDataParser(
|
||||
spatial_merge_size=self.info.get_hf_config(
|
||||
).vision_config.spatial_merge_size,
|
||||
target_sr=feature_extractor.sampling_rate)
|
||||
|
||||
def _call_hf_processor(
|
||||
@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _qwen2_5_omni_thinker_field_config(hf_inputs)
|
||||
return create_qwen2_5_omni_thinker_field_factory(
|
||||
self.info.get_hf_config().vision_config.spatial_merge_size)(
|
||||
hf_inputs)
|
||||
|
||||
def _maybe_apply_prompt_updates(
|
||||
self,
|
||||
|
||||
@ -699,29 +699,46 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_grid_sizes = image_grid_thw.prod(-1)
|
||||
def _create_qwen2vl_field_factory(
|
||||
spatial_merge_size: int
|
||||
) -> Callable[
|
||||
[Mapping[str, torch.Tensor]],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
]:
|
||||
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_pixel_grid_sizes = image_grid_thw.prod(-1)
|
||||
image_embed_grid_sizes = (image_pixel_grid_sizes //
|
||||
spatial_merge_size // spatial_merge_size)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
|
||||
spatial_merge_size)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_pixel_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_embed_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_embed_grid_sizes),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
|
||||
return _qwen2vl_field_config
|
||||
|
||||
|
||||
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
|
||||
|
||||
def __init__(self, spatial_merge_size: int, *args, **kwargs):
|
||||
self._spatial_merge_size = spatial_merge_size
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _parse_image_data(
|
||||
self,
|
||||
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
|
||||
@ -731,7 +748,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
|
||||
data,
|
||||
modality="image",
|
||||
required_fields={"image_embeds", "image_grid_thw"},
|
||||
fields_factory=_qwen2vl_field_config,
|
||||
fields_factory=_create_qwen2vl_field_factory(
|
||||
self._spatial_merge_size),
|
||||
)
|
||||
|
||||
return super()._parse_image_data(data)
|
||||
@ -745,7 +763,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
|
||||
data,
|
||||
modality="video",
|
||||
required_fields={"video_embeds", "video_grid_thw"},
|
||||
fields_factory=_qwen2vl_field_config,
|
||||
fields_factory=_create_qwen2vl_field_factory(
|
||||
self._spatial_merge_size),
|
||||
)
|
||||
|
||||
return super()._parse_video_data(data)
|
||||
@ -967,7 +986,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
):
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return Qwen2VLMultiModalDataParser()
|
||||
return Qwen2VLMultiModalDataParser(
|
||||
self.info.get_hf_config().vision_config.spatial_merge_size)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@ -1010,7 +1030,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _qwen2vl_field_config(hf_inputs)
|
||||
return _create_qwen2vl_field_factory(
|
||||
self.info.get_hf_config().vision_config.spatial_merge_size)(
|
||||
hf_inputs)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
|
||||
|
||||
@ -130,6 +130,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
||||
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
|
||||
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||
|
||||
487
vllm/model_executor/models/seed_oss.py
Normal file
487
vllm/model_executor/models/seed_oss.py
Normal file
@ -0,0 +1,487 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The Seed team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only SeedOss model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig as SeedOssConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SeedOssMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class SeedOssAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class SeedOssDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SeedOssConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
||||
# By default, SeedOss uses causal attention as it is a
|
||||
# decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = SeedOssAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
self.mlp = SeedOssMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class SeedOssModel(nn.Module):
    """Seed-OSS decoder-only transformer backbone.

    Holds the token embedding, the stack of decoder layers, and the final
    RMSNorm. Supports pipeline parallelism (PP): the embedding lives on the
    first PP rank (and also on the last rank when embeddings are tied so the
    LM head can reuse it), and the final norm lives on the last PP rank.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # TODO (@robertgshaw2): see if this can be moved out
        if (cache_config.sliding_window is not None
                and hasattr(config, "max_window_layers")):
            # Partial sliding window (only some layers windowed) is not
            # supported; require the window to cover every layer.
            assert config.max_window_layers == config.num_hidden_layers, (
                "Sliding window for some but not all layers is not supported. "
                "This model uses sliding window but `max_window_layers` = {} "
                "is less than `num_hidden_layers` = {}. Please open an issue "
                "to discuss this feature.".format(
                    config.max_window_layers,
                    config.num_hidden_layers,
                ))

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # Embedding is needed on the first PP rank (input side) and, when
        # tied to the LM head, also on the last PP rank (output side).
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to
        # SeedOssDecoderLayer.
        decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        # Final norm only exists on the last PP rank.
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's slice of decoder layers.

        Returns the final (normed) hidden states on the last PP rank;
        otherwise returns ``IntermediateTensors`` to forward downstream.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            # Non-first ranks receive activations from the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        # Final fused residual-add + norm on the last rank.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model.

        Maps per-projection checkpoint names (q/k/v, gate/up) onto the
        fused parameters, handles kv-cache quantization scales, and skips
        parameters that live on other PP ranks. Returns the set of
        parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rotary inv_freq buffers are recomputed, not loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
||||
|
||||
|
||||
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Seed-OSS model with a causal language-modeling head.

    Wraps :class:`SeedOssModel` and adds the LM head plus logits
    computation. Supports LoRA and pipeline parallelism (PP).
    """

    # Maps each fused parameter to the per-projection checkpoint names it
    # is packed from (consumed by LoRA and the weight loader).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = SeedOssModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))

        # The LM head only exists on the last pipeline-parallel rank.
        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                # Reuse the input embedding weights as the output projection.
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns hidden states on the last PP rank, otherwise the
        intermediate tensors to ship to the next stage.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; skips ``lm_head.*`` entries when the
        output projection is tied to the input embeddings."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
|
||||
@ -3,16 +3,24 @@
|
||||
"""Implementation of SiglipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers import Siglip2VisionConfig
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
|
||||
|
||||
from vllm.config import QuantizationConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import _Backend
|
||||
|
||||
from .vision import get_vit_attn_backend
|
||||
@ -48,10 +56,11 @@ class Siglip2VisionEmbeddings(nn.Module):
|
||||
|
||||
# siglip2 naflex
|
||||
if self.num_patches > 0:
|
||||
self.patch_embedding = nn.Linear(
|
||||
in_features=config.num_channels * self.patch_size *
|
||||
self.patch_embedding = ReplicatedLinear(
|
||||
input_size=config.num_channels * self.patch_size *
|
||||
self.patch_size,
|
||||
out_features=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
return_bias=False,
|
||||
)
|
||||
if self.preserve_original_pe:
|
||||
self.position_embedding_size = int(self.num_patches**0.5)
|
||||
@ -89,7 +98,7 @@ class Siglip2VisionEmbeddings(nn.Module):
|
||||
|
||||
# Apply patch embeddings to already patchified pixel values
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
if isinstance(self.patch_embedding, nn.Linear):
|
||||
if isinstance(self.patch_embedding, LinearBase):
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype))
|
||||
elif isinstance(self.patch_embedding, nn.Conv2d):
|
||||
@ -184,7 +193,13 @@ def apply_rotary_pos_emb(
|
||||
class Siglip2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -199,11 +214,25 @@ class Siglip2Attention(nn.Module):
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
# TODO(Isotr0py): Enable data parallel after we support
|
||||
# disabling TP on parallel linear layer
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
self.use_rope = config.use_rope
|
||||
|
||||
# Detect attention implementation.
|
||||
@ -228,13 +257,15 @@ class Siglip2Attention(nn.Module):
|
||||
|
||||
seq_length, embed_dim = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
queries, keys, values = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
queries = queries.view(seq_length, self.num_heads, self.head_dim)
|
||||
keys = keys.view(seq_length, self.num_heads, self.head_dim)
|
||||
values = values.view(seq_length, self.num_heads, self.head_dim)
|
||||
queries = queries.view(seq_length, self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
keys = keys.view(seq_length, self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
values = values.view(seq_length, self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
if self.use_rope:
|
||||
cos, sin = position_embeddings
|
||||
@ -276,41 +307,72 @@ class Siglip2Attention(nn.Module):
|
||||
v_i,
|
||||
dropout_p=0.0)
|
||||
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
|
||||
output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
|
||||
output_i = output_i.transpose(1, 2).reshape(
|
||||
end_idx - start_idx, -1)
|
||||
outputs.append(output_i)
|
||||
|
||||
attn_output = torch.cat(outputs, dim=0)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Siglip2MLP(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
# TODO(Isotr0py): Enable data parallel after we support
|
||||
# disabling TP on parallel linear layer
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Siglip2EncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.self_attn = Siglip2Attention(config)
|
||||
self.self_attn = Siglip2Attention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = Siglip2MLP(config)
|
||||
self.mlp = Siglip2MLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
||||
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
|
||||
@ -347,14 +409,22 @@ class Siglip2Encoder(nn.Module):
|
||||
config: PretrainedConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
Siglip2EncoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
Siglip2EncoderLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.rotary_pos_emb = VisionRotaryEmbedding(
|
||||
config.hidden_size // config.num_attention_heads // 2)
|
||||
@ -445,13 +515,11 @@ class Siglip2Encoder(nn.Module):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
inputs_embeds: torch.Tensor,
|
||||
grid_thws: torch.Tensor,
|
||||
output_hidden_states: bool = False,
|
||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape
|
||||
@ -506,7 +574,6 @@ class Siglip2Encoder(nn.Module):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
encoder_states = () if output_hidden_states else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for index, block in enumerate(self.layers):
|
||||
@ -517,45 +584,40 @@ class Siglip2Encoder(nn.Module):
|
||||
cu_seqlens_tmp = cu_window_seqlens
|
||||
hidden_states = block(hidden_states, cu_seqlens_tmp,
|
||||
position_embeddings)
|
||||
if output_hidden_states:
|
||||
hidden_states_ = hidden_states.reshape(
|
||||
seq_len // self.spatial_merge_unit,
|
||||
self.spatial_merge_unit, -1)
|
||||
encoder_states += (hidden_states_[reverse_indices, :].reshape(
|
||||
seq_len, -1), )
|
||||
# tokens = self.post_trunk_norm(tokens)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
|
||||
|
||||
return hidden_states, encoder_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Siglip2VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = Siglip2VisionEmbeddings(config)
|
||||
self.encoder = Siglip2Encoder(config)
|
||||
self.encoder = Siglip2Encoder(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self._use_flash_attention_2 = \
|
||||
(config._attn_implementation == "flash_attention_2")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
grid_thws: torch.LongTensor,
|
||||
output_hidden_states: Optional[bool] = True,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[
|
||||
tuple[torch.Tensor],
|
||||
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
|
||||
BaseModelOutputWithNoAttention,
|
||||
]:
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the spatial dimensions (height, width)
|
||||
@ -563,45 +625,64 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
"""
|
||||
hidden_states = self.embeddings(pixel_values, grid_thws)
|
||||
|
||||
last_hidden_state, hidden_states = self.encoder(
|
||||
hidden_states, grid_thws, output_hidden_states)
|
||||
last_hidden_state = self.encoder(hidden_states, grid_thws)
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
if not return_dict:
|
||||
output = (last_hidden_state, )
|
||||
output += (hidden_states, ) if output_hidden_states else ()
|
||||
return output
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class Siglip2NavitModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vision_model = Siglip2VisionTransformer(config)
|
||||
self.vision_model = Siglip2VisionTransformer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
use_data_parallel=use_data_parallel)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
grid_thws: torch.LongTensor,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[
|
||||
tuple[torch.Tensor],
|
||||
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
|
||||
BaseModelOutputWithNoAttention,
|
||||
]:
|
||||
|
||||
if output_hidden_states is None:
|
||||
output_hidden_states = self.config.output_hidden_states
|
||||
if return_dict is None:
|
||||
return_dict = self.config.use_return_dict
|
||||
|
||||
) -> torch.Tensor:
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
grid_thws=grid_thws,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
|
||||
num_chunks_per_rank, ...]
|
||||
|
||||
vision_embeddings = vision_model(image_input_per_rank)
|
||||
# Ensure tensor is contiguous before all_gather
|
||||
vision_embeddings = vision_embeddings.contiguous()
|
||||
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
|
||||
dim=0)
|
||||
vision_embeddings = vision_embeddings[:num_chunks, ...]
|
||||
|
||||
@ -171,7 +171,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
supported_quantization: list[str] = [
|
||||
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
|
||||
"quark", "ptpc_fp8", "mxfp4"
|
||||
"quark", "ptpc_fp8", "mxfp4", "petit_nvfp4"
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -49,12 +49,11 @@ def decode_tokens(
|
||||
`skip_special_tokens=None` means to use the backend's default
|
||||
settings.
|
||||
"""
|
||||
decode_method = getattr(tokenizer, "_decode", tokenizer.decode)
|
||||
if skip_special_tokens is not None:
|
||||
return decode_method(token_ids,
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
return tokenizer.decode(token_ids,
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
return decode_method(token_ids)
|
||||
return tokenizer.decode(token_ids)
|
||||
|
||||
|
||||
def encode_tokens(
|
||||
|
||||
@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None:
|
||||
torch.cuda.set_stream = _patched_set_stream
|
||||
|
||||
|
||||
class _StreamPlaceholder:
|
||||
|
||||
def __init__(self):
|
||||
self.synchronize = lambda: None
|
||||
|
||||
|
||||
def current_stream() -> torch.cuda.Stream:
|
||||
"""
|
||||
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
|
||||
@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream:
|
||||
# On ROCm using the default 0 stream in combination with RCCL
|
||||
# is hurting performance. Therefore creating a dedicated stream
|
||||
# per process
|
||||
_current_stream_tls.value = torch.cuda.Stream(
|
||||
) if current_platform.is_rocm() else torch.cuda.current_stream()
|
||||
if current_platform.is_rocm():
|
||||
_current_stream_tls.value = torch.cuda.Stream()
|
||||
elif current_platform.is_cpu():
|
||||
_current_stream_tls.value = _StreamPlaceholder()
|
||||
else:
|
||||
current_stream = current_platform.current_stream
|
||||
if current_stream is not None:
|
||||
_current_stream_tls.value = current_stream()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Fail to set current stream, current platform "
|
||||
"may not support current_stream with torch API")
|
||||
return _current_stream_tls.value
|
||||
|
||||
|
||||
@ -2466,7 +2482,7 @@ class PlaceholderModule(_PlaceholderBase):
|
||||
A placeholder object to use when a module does not exist.
|
||||
|
||||
This enables more informative errors when trying to access attributes
|
||||
of a module that does not exists.
|
||||
of a module that does not exist.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
@ -3093,7 +3109,7 @@ class LazyLoader(types.ModuleType):
|
||||
"""
|
||||
LazyLoader module borrowed from Tensorflow
|
||||
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
|
||||
with a addition of "module caching".
|
||||
with an addition of "module caching".
|
||||
|
||||
Lazily import a module, mainly to avoid pulling in large dependencies.
|
||||
Modules such as `xgrammar` might do additional side effects, so we
|
||||
|
||||
@ -132,6 +132,11 @@ def has_nvidia_artifactory() -> bool:
|
||||
This checks connectivity to the kernel inference library artifactory
|
||||
which is required for downloading certain cubin kernels like TRTLLM FHMA.
|
||||
"""
|
||||
# Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when
|
||||
# it's true, we could assume the cubins are available.
|
||||
if envs.VLLM_HAS_FLASHINFER_CUBIN:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Use a short timeout to avoid blocking for too long
|
||||
response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user