mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 11:07:03 +08:00
Merge branch 'main' into fix_hang
This commit is contained in:
commit
562107efb1
@ -44,7 +44,6 @@ docker run \
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
||||
pytest -v -s v1/test_metrics
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
pytest -v -s v1/test_utils.py
|
||||
pytest -v -s v1/test_metrics_reader.py
|
||||
'
|
||||
|
||||
@ -159,10 +159,7 @@ steps:
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/test_internal_lb_dp.py
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/distributed
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
- tests/distributed/test_symm_mem_allreduce.py
|
||||
commands:
|
||||
@ -180,10 +177,10 @@ steps:
|
||||
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -300,12 +297,9 @@ steps:
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/kv_connector/unit
|
||||
- pytest -v -s v1/metrics
|
||||
- pytest -v -s v1/test_kv_sharing.py
|
||||
- pytest -v -s v1/test_metrics_reader.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
- pytest -v -s v1/test_request.py
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
@ -465,29 +459,18 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 14min
|
||||
timeout_in_minutes: 25
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
- tests/tensorizer_loader
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s tensorizer_loader
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Model Executor Test # 7min
|
||||
timeout_in_minutes: 20
|
||||
- label: Model Executor Test # 23min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
@ -906,14 +889,13 @@ steps:
|
||||
- tests/compile/test_wrapper.py
|
||||
- tests/distributed/
|
||||
- tests/entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/distributed
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- tests/v1/shutdown
|
||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
|
||||
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -12,8 +12,6 @@
|
||||
/vllm/model_executor/layers/mamba @tdoublep
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/vllm/v1/attention @LucasWilkinson
|
||||
/vllm/v1/sample @22quinn @houseroad
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm @chaunceyjiang
|
||||
@ -28,11 +26,13 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||
/vllm/v1/spec_decode @benchislett @luccafong
|
||||
/vllm/v1/attention @LucasWilkinson
|
||||
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||
/vllm/v1/spec_decode @benchislett @luccafong
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||
/vllm/v1/offloading @ApostaC
|
||||
|
||||
@ -54,7 +54,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||
/tests/lora @jeejeelee
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector @ApostaC
|
||||
/tests/v1/offloading @ApostaC
|
||||
|
||||
|
||||
2
.github/mergify.yml
vendored
2
.github/mergify.yml
vendored
@ -274,7 +274,7 @@ pull_request_rules:
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- files~=^tests/tensorizer_loader/
|
||||
- files~=^tests/model_executor/model_loader/tensorizer_loader/
|
||||
actions:
|
||||
assign:
|
||||
users:
|
||||
|
||||
16
csrc/core/batch_invariant.hpp
Normal file
16
csrc/core/batch_invariant.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// vllm_kernel_override_batch_invariant(); returns true
|
||||
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
|
||||
inline bool vllm_kernel_override_batch_invariant() {
|
||||
std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
|
||||
const char* val = std::getenv(env_key.c_str());
|
||||
return (val && std::atoi(val) != 0) ? 1 : 0;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,6 +1,7 @@
|
||||
#include "type_convert.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_POLY_NORM(0);
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include "quantization/fp8/common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
#include "../cub_helpers.h"
|
||||
#include "../core/batch_invariant.hpp"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
|
||||
@ -391,18 +391,28 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
pushd flashinfer
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
|
||||
echo "🏗️ Installing FlashInfer from pre-compiled wheel"
|
||||
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Download pre-compiled cubins
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
|
||||
fi
|
||||
elif [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
|
||||
@ -536,7 +546,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3]
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3]>=0.14.0'
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
||||
@ -66,35 +66,12 @@ Further update the model as follows:
|
||||
!!! important
|
||||
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
|
||||
|
||||
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
|
||||
!!! note
|
||||
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
|
||||
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
|
||||
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from .utils import merge_multimodal_embeddings
|
||||
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index)
|
||||
|
||||
return inputs_embeds
|
||||
```
|
||||
You may override this method if additional logic is required for your model when merging embeddings.
|
||||
|
||||
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
|
||||
|
||||
|
||||
@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
||||
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
|
||||
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
|
||||
|
||||
!!! tip
|
||||
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||
|
||||
## Offline Inference
|
||||
|
||||
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
|
||||
|
||||
@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
|
||||
- Implement proper authentication and authorization for management interfaces
|
||||
- Follow the principle of least privilege for all system components
|
||||
|
||||
### 4. **Restrict Domains Access for Media URLs:**
|
||||
|
||||
Restrict domains that vLLM can access for media URLs by setting
|
||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||
|
||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||
|
||||
While vLLM is designed to allow unsafe network services to be isolated to
|
||||
|
||||
@ -38,11 +38,13 @@ client = OpenAI(
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
headers = {"User-Agent": "vLLM Example Client"}
|
||||
|
||||
|
||||
def encode_base64_content_from_url(content_url: str) -> str:
|
||||
"""Encode a content retrieved from a remote url to base64 format."""
|
||||
|
||||
with requests.get(content_url) as response:
|
||||
with requests.get(content_url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
result = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str:
|
||||
|
||||
|
||||
# Text-only inference
|
||||
def run_text_only(model: str) -> None:
|
||||
def run_text_only(model: str, max_completion_tokens: int) -> None:
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "What's the capital of France?"}],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion.choices[0].message.content
|
||||
print("Chat completion output:", result)
|
||||
print("Chat completion output:\n", result)
|
||||
|
||||
|
||||
# Single-image input inference
|
||||
def run_single_image(model: str) -> None:
|
||||
def run_single_image(model: str, max_completion_tokens: int) -> None:
|
||||
## Use image url in the payload
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -79,11 +81,11 @@ def run_single_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from image url:", result)
|
||||
print("Chat completion output from image url:\n", result)
|
||||
|
||||
## Use base64 encoded image in the payload
|
||||
image_base64 = encode_base64_content_from_url(image_url)
|
||||
@ -101,7 +103,7 @@ def run_single_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
@ -109,7 +111,7 @@ def run_single_image(model: str) -> None:
|
||||
|
||||
|
||||
# Multi-image input inference
|
||||
def run_multi_image(model: str) -> None:
|
||||
def run_multi_image(model: str, max_completion_tokens: int) -> None:
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -130,15 +132,15 @@ def run_multi_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output:", result)
|
||||
print("Chat completion output:\n", result)
|
||||
|
||||
|
||||
# Video input inference
|
||||
def run_video(model: str) -> None:
|
||||
def run_video(model: str, max_completion_tokens: int) -> None:
|
||||
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"
|
||||
video_base64 = encode_base64_content_from_url(video_url)
|
||||
|
||||
@ -157,11 +159,11 @@ def run_video(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from image url:", result)
|
||||
print("Chat completion output from video url:\n", result)
|
||||
|
||||
## Use base64 encoded video in the payload
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
@ -178,15 +180,15 @@ def run_video(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from base64 encoded image:", result)
|
||||
print("Chat completion output from base64 encoded video:\n", result)
|
||||
|
||||
|
||||
# Audio input inference
|
||||
def run_audio(model: str) -> None:
|
||||
def run_audio(model: str, max_completion_tokens: int) -> None:
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
audio_url = AudioAsset("winning_call").url
|
||||
@ -211,11 +213,11 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from input audio:", result)
|
||||
print("Chat completion output from input audio:\n", result)
|
||||
|
||||
# HTTP URL
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -235,11 +237,11 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from audio url:", result)
|
||||
print("Chat completion output from audio url:\n", result)
|
||||
|
||||
# base64 URL
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
@ -259,14 +261,14 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from base64 encoded audio:", result)
|
||||
print("Chat completion output from base64 encoded audio:\n", result)
|
||||
|
||||
|
||||
def run_multi_audio(model: str) -> None:
|
||||
def run_multi_audio(model: str, max_completion_tokens: int) -> None:
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# Two different audios to showcase batched inference.
|
||||
@ -300,11 +302,11 @@ def run_multi_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from input audio:", result)
|
||||
print("Chat completion output from input audio:\n", result)
|
||||
|
||||
|
||||
example_function_map = {
|
||||
@ -330,13 +332,20 @@ def parse_args():
|
||||
choices=list(example_function_map.keys()),
|
||||
help="Conversation type with multimodal data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-completion-tokens",
|
||||
"-n",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of tokens to generate for each completion.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
chat_type = args.chat_type
|
||||
model = get_first_model(client)
|
||||
example_function_map[chat_type](model)
|
||||
example_function_map[chat_type](model, args.max_completion_tokens)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -43,7 +43,6 @@ tritonclient==2.51.0
|
||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
numpy
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
|
||||
@ -5,8 +5,6 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
|
||||
# Dependencies for AMD GPUs
|
||||
boto3
|
||||
botocore
|
||||
datasets
|
||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
peft
|
||||
@ -15,7 +13,6 @@ tensorizer==2.10.1
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
conch-triton-kernels==1.2.1
|
||||
timm>=1.0.17
|
||||
@ -51,8 +51,7 @@ tritonclient==2.51.0
|
||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
numpy
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
decord==0.6.0
|
||||
|
||||
@ -72,7 +72,9 @@ blobfile==3.0.0
|
||||
bm25s==0.2.13
|
||||
# via mteb
|
||||
boto3==1.35.57
|
||||
# via tensorizer
|
||||
# via
|
||||
# runai-model-streamer-s3
|
||||
# tensorizer
|
||||
botocore==1.35.57
|
||||
# via
|
||||
# boto3
|
||||
@ -925,10 +927,10 @@ rsa==4.9.1
|
||||
# via google-auth
|
||||
rtree==1.4.0
|
||||
# via torchgeo
|
||||
runai-model-streamer==0.11.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer==0.14.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.14.0
|
||||
# via runai-model-streamer
|
||||
s3transfer==0.10.3
|
||||
# via boto3
|
||||
sacrebleu==2.4.3
|
||||
|
||||
5
setup.py
5
setup.py
@ -654,10 +654,7 @@ setup(
|
||||
"bench": ["pandas", "datasets"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": [
|
||||
"runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
|
||||
"google-cloud-storage", "runai-model-streamer-s3", "boto3"
|
||||
],
|
||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||
"audio": ["librosa", "soundfile",
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
|
||||
@ -45,6 +45,7 @@ class MockModelConfig:
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
allowed_local_media_path: str = ""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
|
||||
@ -240,6 +240,7 @@ class MockModelConfig:
|
||||
logits_processor_pattern = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
allowed_local_media_path: str = ""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format,
|
||||
resolve_chat_template_kwargs,
|
||||
resolve_hf_chat_template)
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
||||
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
|
||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
||||
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
||||
assert isinstance(chat_template, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_kwargs",
|
||||
[
|
||||
(
|
||||
QWEN2VL_MODEL_ID,
|
||||
{
|
||||
"add_vision_id", "add_generation_prompt",
|
||||
"continue_final_message", "tools"
|
||||
},
|
||||
),
|
||||
(
|
||||
QWEN3_MODEL_ID,
|
||||
{
|
||||
"enable_thinking", "add_generation_prompt",
|
||||
"continue_final_message", "tools"
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
|
||||
expected_kwargs):
|
||||
"""checks that chat_template is a dict type for HF models."""
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
tools = ([{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema,
|
||||
},
|
||||
}])
|
||||
|
||||
chat_template_kwargs = {
|
||||
# both unused
|
||||
"unsed_kwargs_1": 123,
|
||||
"unsed_kwargs_2": "abc",
|
||||
# should not appear
|
||||
"chat_template": "{% Hello world! %}",
|
||||
# used by tokenizer
|
||||
"continue_final_message": True,
|
||||
"tools": tools,
|
||||
# both used by Qwen2-VL and Qwen3
|
||||
"add_generation_prompt": True,
|
||||
# only used by Qwen2-VL
|
||||
"add_vision_id": True,
|
||||
# only used by Qwen3
|
||||
"enable_thinking": True,
|
||||
}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model,
|
||||
tokenizer=model_info.tokenizer or model,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
|
||||
# Build the tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
)
|
||||
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
|
||||
|
||||
|
||||
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
||||
# processor class instead of using `tokenizer_config.json`
|
||||
# yapf: disable
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_regex():
|
||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"maxLength": 10
|
||||
},
|
||||
"minItems": 3
|
||||
},
|
||||
"work_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "string"
|
||||
},
|
||||
"duration": {
|
||||
"type": "number"
|
||||
},
|
||||
"position": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["company", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["name", "age", "skills", "work_history"]
|
||||
}
|
||||
@ -14,6 +14,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.model_loader.tensorizer
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
# yapf: disable
|
||||
@ -27,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer_loader import (
|
||||
# yapf: enable
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from .conftest import DummyExecutor, assert_from_collective_rpc
|
||||
|
||||
try:
|
||||
@ -651,6 +651,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
||||
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
|
||||
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.3"),
|
||||
}
|
||||
|
||||
@ -100,10 +100,9 @@ def test_distributed(
|
||||
kwargs_test=kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="bitsandbytes quantization is currently not supported in rocm.")
|
||||
@pytest.mark.parametrize("model, quantization_kwargs", [
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
|
||||
(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
@ -121,6 +120,11 @@ def test_quantization(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
if (current_platform.is_rocm()
|
||||
and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
|
||||
pytest.skip(
|
||||
"bitsandbytes quantization is currently not supported in rocm.")
|
||||
|
||||
with vllm_runner(
|
||||
model, model_impl="auto", enforce_eager=True,
|
||||
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||
|
||||
@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
|
||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
|
||||
raw_image_url: str, suffix: str):
|
||||
connector = MediaConnector()
|
||||
connector = MediaConnector(
|
||||
# Domain restriction should not apply to data URLs.
|
||||
allowed_media_domains=[
|
||||
"www.bogotobogo.com",
|
||||
"github.com",
|
||||
])
|
||||
url_image = url_images[raw_image_url]
|
||||
|
||||
try:
|
||||
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
|
||||
modality_idxs = argsort_mm_positions(mm_positions)
|
||||
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
async def test_allowed_media_domains(video_url: str, num_frames: int):
|
||||
connector = MediaConnector(
|
||||
media_io_kwargs={"video": {
|
||||
"num_frames": num_frames,
|
||||
}},
|
||||
allowed_media_domains=[
|
||||
"www.bogotobogo.com",
|
||||
"github.com",
|
||||
])
|
||||
|
||||
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||
assert np.array_equal(video_sync, video_async)
|
||||
assert metadata_sync == metadata_async
|
||||
|
||||
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
|
||||
with pytest.raises(ValueError):
|
||||
_, _ = connector.fetch_video(disallowed_url)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_, _ = await connector.fetch_video_async(disallowed_url)
|
||||
|
||||
0
tests/v1/distributed/__init__.py
Normal file
0
tests/v1/distributed/__init__.py
Normal file
@ -12,7 +12,7 @@ import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
@ -13,7 +13,7 @@ import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
@ -8,7 +8,7 @@ from typing import Any, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
@ -88,69 +88,71 @@ def test_ngram_correctness(
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches >= int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches >= int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
(("eagle", "eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||
"llama4_eagle", "llama4_eagle_mm",
|
||||
"deepseek_eagle"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled"],
|
||||
[
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to its " \
|
||||
"head_dim not being a a multiple of 32")),
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
(("eagle", "eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
|
||||
])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
@ -174,9 +176,14 @@ def test_eagle_correctness(
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
# Scout requires default backend selection
|
||||
# because vision encoder has head_dim 88 being incompatible
|
||||
# with FLASH_ATTN and needs to fall back to Flex Attn
|
||||
pass
|
||||
else:
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN does not support "
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
|
||||
|
||||
290
tests/v1/generation/test_batch_invariance.py
Normal file
290
tests/v1/generation/test_batch_invariance.py
Normal file
@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Lightweight random prompt generator to vary prompt lengths and content.
|
||||
vocab = [
|
||||
"alpha",
|
||||
"bravo",
|
||||
"charlie",
|
||||
"delta",
|
||||
"echo",
|
||||
"foxtrot",
|
||||
"golf",
|
||||
"hotel",
|
||||
"india",
|
||||
"juliet",
|
||||
"kilo",
|
||||
"lima",
|
||||
"mike",
|
||||
"november",
|
||||
"oscar",
|
||||
"papa",
|
||||
"quebec",
|
||||
"romeo",
|
||||
"sierra",
|
||||
"tango",
|
||||
"uniform",
|
||||
"victor",
|
||||
"whiskey",
|
||||
"xray",
|
||||
"yankee",
|
||||
"zulu",
|
||||
]
|
||||
n = random.randint(min_words, max_words)
|
||||
words = random.choices(vocab, k=n)
|
||||
|
||||
# Add some noise and punctuation variability
|
||||
if random.random() < 0.5:
|
||||
words[0] = words[0].capitalize()
|
||||
if random.random() < 0.2:
|
||||
words.append("".join(random.choices(string.ascii_lowercase, k=5)))
|
||||
punct = random.choice([".", "?", "!", "...", ""])
|
||||
return " ".join(words) + punct
|
||||
|
||||
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
"""
|
||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||
using the high-level v1 LLM() API only (no manual batching).
|
||||
|
||||
Strategy:
|
||||
- Create two LLM engines with identical config except max_num_seqs: 1 vs N.
|
||||
- Compute a baseline output for the needle prompt with the bs=1 engine.
|
||||
- For many trials, generate a batch (size N) where the needle appears at a
|
||||
random position among random filler prompts using the bs=N engine.
|
||||
- Track how many trials match vs mismatch, and report totals at the end.
|
||||
The test fails if any mismatches occur, but we still dump pass/fail
|
||||
counts.
|
||||
|
||||
Notes:
|
||||
- Use seeded stochastic sampling with a fixed seed to test determinism.
|
||||
- Outputs are intentionally longer and sampled at higher temperature/top_p
|
||||
to produce a more random-sounding phrase, yet remain deterministic by
|
||||
seed.
|
||||
- Keep max_tokens and max_model_len bounded for speed and memory use.
|
||||
"""
|
||||
random.seed(12345)
|
||||
|
||||
# Allow overrides from environment (useful for CI tuning)
|
||||
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
|
||||
batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
|
||||
assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
|
||||
|
||||
# Keep GPU memory usage low to avoid startup allocation failures.
|
||||
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
|
||||
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
|
||||
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
|
||||
|
||||
# Sampling parameters: longer outputs with a more random-sounding
|
||||
# continuation,but still deterministic due to fixed seed.
|
||||
temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
|
||||
top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
|
||||
max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))
|
||||
|
||||
sampling = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
seed=20240919,
|
||||
)
|
||||
|
||||
needle_prompt = ("There once was a ")
|
||||
|
||||
llm_bs1 = None
|
||||
llm_bsN = None
|
||||
try:
|
||||
# Engine with bs=1 behavior
|
||||
llm_bs1 = LLM_with_max_seqs(
|
||||
model=model,
|
||||
max_num_seqs=1,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
# Baseline generation for the needle prompt alone.
|
||||
baseline_out = llm_bs1.generate([needle_prompt], sampling)
|
||||
assert len(baseline_out) == 1
|
||||
assert len(baseline_out[0].outputs) >= 1
|
||||
baseline_text = baseline_out[0].outputs[0].text
|
||||
|
||||
# Engine with larger batch limit (e.g., 64)
|
||||
llm_bsN = LLM_with_max_seqs(
|
||||
model=model,
|
||||
max_num_seqs=batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
mismatches = 0
|
||||
|
||||
for trial in range(num_trials):
|
||||
# Create a batch of size `batch_size` and insert the needle at
|
||||
# a random index
|
||||
prompts: list[str] = []
|
||||
needle_pos = random.randint(0, batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
if i == needle_pos:
|
||||
prompts.append(needle_prompt)
|
||||
else:
|
||||
prompts.append(_random_prompt())
|
||||
|
||||
# Generate with the larger-batch engine
|
||||
outputs = llm_bsN.generate(prompts, sampling)
|
||||
# Find the needle output by position
|
||||
needle_output = outputs[needle_pos]
|
||||
assert needle_output.prompt == needle_prompt
|
||||
assert len(needle_output.outputs) >= 1
|
||||
text = needle_output.outputs[0].text
|
||||
|
||||
if text != baseline_text:
|
||||
mismatches += 1
|
||||
|
||||
passes = num_trials - mismatches
|
||||
# Dump how many passed vs failed
|
||||
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, batch_size={batch_size}")
|
||||
|
||||
if mismatches > 0:
|
||||
pytest.fail(
|
||||
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||
f"of {num_trials} trials (batch_size={batch_size}).")
|
||||
|
||||
finally:
|
||||
# Ensure engines are shutdown to free GPU/VRAM across test sessions
|
||||
if llm_bs1 is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
llm_bs1.shutdown()
|
||||
if llm_bsN is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
llm_bsN.shutdown()
|
||||
|
||||
|
||||
def _extract_step_logprobs(request_output):
|
||||
if getattr(request_output, "outputs", None):
|
||||
inner = request_output.outputs[0]
|
||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||
t = torch.tensor(
|
||||
[
|
||||
inner.logprobs[i][tid].logprob
|
||||
for i, tid in enumerate(inner.token_ids)
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||
|
||||
#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
# Force float32 to avoid precision-induced differences.
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enforce_eager=True, # helps reduce nondeterminism from some backends
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of Germany is",
|
||||
]
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
max_tokens=8,
|
||||
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
|
||||
seed=1234,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
# BS=1: run prompts individually and collect logprobs per step.
|
||||
bs1_logprobs_per_prompt = []
|
||||
for p in prompts:
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# BS=2: run prompts in a batch and collect logprobs per step for each
|
||||
# prompt.
|
||||
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
|
||||
assert len(outs_batched) == len(prompts)
|
||||
bs2_logprobs_per_prompt = []
|
||||
for o in outs_batched:
|
||||
step_logprobs = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
bs2_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
|
||||
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
|
||||
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
|
||||
assert len(logprobs_bs1) == len(logprobs_bs2), (
|
||||
f"Different number of generation steps for prompt index {i}: "
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
|
||||
assert a.shape == b.shape, (
|
||||
f"Logits shape mismatch at prompt {i}, step {t}: "
|
||||
f"{a.shape} vs {b.shape}")
|
||||
# Bitwise exact equality.
|
||||
assert torch.equal(
|
||||
a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||
f"(dtype={a.dtype}, shape={a.shape}).")
|
||||
|
||||
|
||||
def LLM_with_max_seqs(
|
||||
model: str,
|
||||
max_num_seqs: int,
|
||||
gpu_memory_utilization: float,
|
||||
max_model_len: int,
|
||||
swap_space: int,
|
||||
) -> LLM:
|
||||
"""
|
||||
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
|
||||
using the high-level v1 LLM API, while constraining memory usage.
|
||||
"""
|
||||
return LLM(
|
||||
model=model,
|
||||
max_num_seqs=max_num_seqs,
|
||||
# Constrain GPU memory pool so test can run even on busy GPUs.
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
# Keep KV cache footprint small while allowing longer outputs.
|
||||
max_model_len=max_model_len,
|
||||
# Allow some CPU offload if needed.
|
||||
swap_space=swap_space,
|
||||
# Keep things lean and CI-friendly.
|
||||
dtype="float16",
|
||||
# Single-GPU by default; override externally if desired.
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
|
||||
)
|
||||
@ -1,71 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
from vllm.attention import Attention
|
||||
|
||||
ctx = {
|
||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'layers.0.self_attn': torch.zeros((1, )),
|
||||
'layers.1.self_attn': torch.zeros((1, )),
|
||||
'layers.2.self_attn': torch.zeros((1, )),
|
||||
'layers.3.self_attn': torch.zeros((1, )),
|
||||
}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.0.self_attn']
|
||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.1.self_attn']
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.2.self_attn']
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.3.self_attn']
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
||||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
||||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
||||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
from vllm.attention import Attention
|
||||
|
||||
# example from Jamba PP=2
|
||||
ctx = {
|
||||
'model.layers.20.attn': Attention(32, 128, 0.1),
|
||||
'model.layers.28.attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'model.layers.20.attn': torch.zeros((1, )),
|
||||
'model.layers.28.attn': torch.zeros((1, )),
|
||||
}
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
|
||||
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.20.attn']
|
||||
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.28.attn']
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
|
||||
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
|
||||
|
||||
|
||||
# Prometheus metrics utilities for testing
|
||||
|
||||
63
tests/v1/worker/test_utils.py
Normal file
63
tests/v1/worker/test_utils.py
Normal file
@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
from vllm.attention import Attention
|
||||
|
||||
ctx = {
|
||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'layers.0.self_attn': torch.zeros((1, )),
|
||||
'layers.1.self_attn': torch.zeros((1, )),
|
||||
'layers.2.self_attn': torch.zeros((1, )),
|
||||
'layers.3.self_attn': torch.zeros((1, )),
|
||||
}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.0.self_attn']
|
||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.1.self_attn']
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.2.self_attn']
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.3.self_attn']
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
||||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
||||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
||||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
from vllm.attention import Attention
|
||||
|
||||
# example from Jamba PP=2
|
||||
ctx = {
|
||||
'model.layers.20.attn': Attention(32, 128, 0.1),
|
||||
'model.layers.28.attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'model.layers.20.attn': torch.zeros((1, )),
|
||||
'model.layers.28.attn': torch.zeros((1, )),
|
||||
}
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
|
||||
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.20.attn']
|
||||
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.28.attn']
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
|
||||
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
|
||||
63
tools/flashinfer-build.sh
Normal file
63
tools/flashinfer-build.sh
Normal file
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env bash
|
||||
# This script is used to build FlashInfer wheels with AOT kernels
|
||||
|
||||
set -ex
|
||||
|
||||
# FlashInfer configuration
|
||||
FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}"
|
||||
CUDA_VERSION="${CUDA_VERSION}"
|
||||
BUILD_WHEEL="${BUILD_WHEEL:-true}"
|
||||
|
||||
if [[ -z "${FLASHINFER_GIT_REF}" ]]; then
|
||||
echo "❌ FLASHINFER_GIT_REF must be specified" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "${CUDA_VERSION}" ]]; then
|
||||
echo "❌ CUDA_VERSION must be specified" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "🏗️ Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}"
|
||||
|
||||
# Clone FlashInfer
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
|
||||
# Set CUDA arch list based on CUDA version
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
|
||||
echo "🏗️ Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
pushd flashinfer
|
||||
# Make sure the wheel is built for the correct CUDA version
|
||||
export UV_TORCH_BACKEND=cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# Build AOT kernels
|
||||
export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
python3 -m flashinfer.aot
|
||||
|
||||
if [[ "${BUILD_WHEEL}" == "true" ]]; then
|
||||
# Build wheel for distribution
|
||||
uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist .
|
||||
echo "✅ FlashInfer wheel built successfully in flashinfer-dist/"
|
||||
else
|
||||
# Install directly (for Dockerfile)
|
||||
uv pip install --system --no-build-isolation --force-reinstall .
|
||||
echo "✅ FlashInfer installed successfully"
|
||||
fi
|
||||
popd
|
||||
|
||||
# Cleanup
|
||||
rm -rf flashinfer
|
||||
@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
):
|
||||
dataset_class = MLPerfDataset
|
||||
args.hf_split = "train"
|
||||
elif (
|
||||
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MMStarDataset
|
||||
args.hf_split = "val"
|
||||
args.hf_subset = None
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
@ -2721,3 +2728,76 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
||||
|
||||
random.shuffle(requests)
|
||||
return requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MMStar Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MMStarDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
|
||||
refer to: https://github.com/sgl-project/SpecForge/pull/106
|
||||
"""
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
no_oversample: bool = False,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
# If --hf-output-len is not set, use the default output length.
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests: list[SampleRequest] = []
|
||||
|
||||
for ind, item in enumerate(self.data):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
# Split the question text from options
|
||||
# (keep only the part before "Options:").
|
||||
full_q: str = item.get("question", "")
|
||||
question_text = full_q.split("Options:", 1)[0].strip()
|
||||
|
||||
# Multimodal image content.
|
||||
mm_content = process_image(item["image"])
|
||||
|
||||
# Compute prompt token length (note: this is plain text length
|
||||
# if enable_multimodal_chat is False).
|
||||
prompt_len = len(tokenizer(question_text).input_ids)
|
||||
|
||||
if enable_multimodal_chat:
|
||||
# If multimodal content should be embedded in the chat message,
|
||||
# convert to [{"role":"user","content":[...]}]
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
question_text, mm_content
|
||||
)
|
||||
mm_for_request = None # Already embedded in chat content.
|
||||
else:
|
||||
# Default: prompt is plain text,
|
||||
# image is in mm_content for the bench to assemble.
|
||||
prompt = question_text
|
||||
mm_for_request = mm_content
|
||||
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_for_request,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix, no_oversample
|
||||
)
|
||||
return sampled_requests
|
||||
|
||||
@ -137,6 +137,9 @@ class ModelConfig:
|
||||
"""Allowing API requests to read local images or videos from directories
|
||||
specified by the server file system. This is a security risk. Should only
|
||||
be enabled in trusted environments."""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
"""If set, only media URLs that belong to this domain can be used for
|
||||
multi-modal inputs. """
|
||||
revision: Optional[str] = None
|
||||
"""The specific model version to use. It can be a branch name, a tag name,
|
||||
or a commit id. If unspecified, will use the default version."""
|
||||
@ -506,9 +509,14 @@ class ModelConfig:
|
||||
else: # task == "auto"
|
||||
pass
|
||||
else:
|
||||
debug_info = {
|
||||
"architectures": architectures,
|
||||
"is_generative_model": is_generative_model,
|
||||
"is_pooling_model": is_pooling_model,
|
||||
}
|
||||
raise AssertionError("The model should be a generative or "
|
||||
"pooling model when task is set to "
|
||||
f"{self.task!r}.")
|
||||
f"{self.task!r}. Found: {debug_info}")
|
||||
|
||||
self.runner = runner
|
||||
self.convert = convert
|
||||
|
||||
@ -279,6 +279,24 @@ class ParallelConfig:
|
||||
assert last_exc is not None
|
||||
raise last_exc
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
@property
|
||||
def use_sequence_parallel_moe(self) -> bool:
|
||||
return (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("allgather_reducescatter", "naive",
|
||||
"deepep_high_throughput", "deepep_low_latency")
|
||||
and self.enable_expert_parallel
|
||||
and self.tensor_parallel_size > 1
|
||||
and self.data_parallel_size > 1)
|
||||
|
||||
@staticmethod
|
||||
def has_unfinished_dp(dp_group: ProcessGroup,
|
||||
has_unfinished: bool) -> bool:
|
||||
|
||||
@ -288,6 +288,8 @@ class SpeculativeConfig:
|
||||
trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.
|
||||
allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.
|
||||
allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
|
||||
@ -6,7 +6,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool) -> torch.Tensor:
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = (self.world_size
|
||||
if is_sequence_parallel else self.dp_world_size)
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
self.initialized = False
|
||||
|
||||
@ -28,6 +28,8 @@ class Cache:
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
@ -40,6 +42,7 @@ class All2AllManagerBase:
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
@ -60,17 +63,21 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
|
||||
@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
|
||||
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
SymmMemCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return output_list
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
||||
@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
||||
@ -871,17 +871,24 @@ class GroupCoordinator:
|
||||
model)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(hidden_states,
|
||||
router_logits)
|
||||
router_logits,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.combine(hidden_states)
|
||||
return self.device_communicator.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -297,6 +297,8 @@ class EngineArgs:
|
||||
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||
allowed_media_domains: Optional[
|
||||
list[str]] = ModelConfig.allowed_media_domains
|
||||
download_dir: Optional[str] = LoadConfig.download_dir
|
||||
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||
@ -531,6 +533,8 @@ class EngineArgs:
|
||||
**model_kwargs["hf_config_path"])
|
||||
model_group.add_argument("--allowed-local-media-path",
|
||||
**model_kwargs["allowed_local_media_path"])
|
||||
model_group.add_argument("--allowed-media-domains",
|
||||
**model_kwargs["allowed_media_domains"])
|
||||
model_group.add_argument("--revision", **model_kwargs["revision"])
|
||||
model_group.add_argument("--code-revision",
|
||||
**model_kwargs["code_revision"])
|
||||
@ -997,6 +1001,7 @@ class EngineArgs:
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
allowed_local_media_path=self.allowed_local_media_path,
|
||||
allowed_media_domains=self.allowed_media_domains,
|
||||
dtype=self.dtype,
|
||||
seed=self.seed,
|
||||
revision=self.revision,
|
||||
|
||||
@ -11,7 +11,12 @@ from pathlib import Path
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import jinja2.meta
|
||||
import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import random_uuid, supports_kw
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
|
||||
@property
|
||||
def allowed_media_domains(self):
|
||||
return self._model_config.allowed_media_domains
|
||||
|
||||
@property
|
||||
def mm_registry(self):
|
||||
return MULTIMODAL_REGISTRY
|
||||
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||
# only preserve the parse function used to resolve chat template kwargs
|
||||
class AssistantTracker(jinja2.ext.Extension):
|
||||
tags = {"generation"}
|
||||
|
||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||
lineno = next(parser.stream).lineno
|
||||
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||
call = self.call_method("_generation_support")
|
||||
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||
return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
chat_template: str,
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
fn_kw = {
|
||||
k for k in chat_template_kwargs
|
||||
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||
}
|
||||
|
||||
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||
)
|
||||
parsed_content = env.parse(chat_template)
|
||||
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# We exclude chat_template from kwargs here, because
|
||||
# chat template has been already resolved at this stage
|
||||
unexpected_vars = {"chat_template"}
|
||||
accept_vars = (fn_kw | template_vars) - unexpected_vars
|
||||
return {
|
||||
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
|
||||
}
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: list[ConversationMessage],
|
||||
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
|
||||
)
|
||||
|
||||
try:
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=hf_chat_template,
|
||||
chat_template_kwargs=kwargs,
|
||||
)
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
chat_template=hf_chat_template,
|
||||
tokenize=tokenize,
|
||||
**kwargs,
|
||||
**resolved_kwargs,
|
||||
)
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
|
||||
@ -86,6 +86,8 @@ class LLM:
|
||||
or videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
allowed_media_domains: If set, only media URLs that belong to this
|
||||
domain can be used for multi-modal inputs.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@ -169,6 +171,7 @@ class LLM:
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: Optional[list[str]] = None,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: ModelDType = "auto",
|
||||
quantization: Optional[QuantizationMethods] = None,
|
||||
@ -264,6 +267,7 @@ class LLM:
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
allowed_media_domains=allowed_media_domains,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
|
||||
@ -3,12 +3,14 @@
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import secrets
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization header exists and equals "Bearer {api_key}".
|
||||
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = {f"Bearer {token}" for token in tokens}
|
||||
self.api_tokens = [
|
||||
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
|
||||
]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and headers.get(
|
||||
"Authorization") not in self.api_tokens:
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return response(scope, receive, send)
|
||||
@ -1696,6 +1716,7 @@ async def init_app_state(
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.
|
||||
|
||||
@ -103,9 +103,13 @@ class FrontendArgs:
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
|
||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: Optional[str] = None
|
||||
|
||||
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
reasoning_parser: str = "",
|
||||
enable_auto_tools: bool = False,
|
||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.response_role = response_role
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
|
||||
# set up tool use
|
||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if not self.use_harmony:
|
||||
# Common case.
|
||||
request_chat_template = request.chat_template
|
||||
chat_template_kwargs = request.chat_template_kwargs
|
||||
if not self.trust_request_chat_template and (
|
||||
request_chat_template is not None or
|
||||
(chat_template_kwargs and
|
||||
chat_template_kwargs.get("chat_template") is not None)):
|
||||
return self.create_error_response(
|
||||
"Chat template is passed with request, but "
|
||||
"--trust-request-chat-template is not set. "
|
||||
"Refused request with untrusted chat template.")
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template=request_chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
|
||||
@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
|
||||
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int) -> list[int]:
|
||||
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
|
||||
sequence_parallel_size)
|
||||
|
||||
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
|
||||
return sp_tokens.tolist()
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int,
|
||||
max_num_tokens: int,
|
||||
chunk_idx: int) -> list[int]:
|
||||
dp_size = len(num_tokens_across_dp_cpu)
|
||||
|
||||
local_size = [-1] * dp_size
|
||||
for i in range(dp_size):
|
||||
dp_tokens = num_tokens_across_dp_cpu[i]
|
||||
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
|
||||
sequence_parallel_size)
|
||||
sp_size = len(sp_tokens)
|
||||
|
||||
local_size = [-1] * sp_size
|
||||
for i in range(sp_size):
|
||||
# Take into account sharding if MoE activation is sequence parallel.
|
||||
local_size[i] = min(max_num_tokens,
|
||||
dp_tokens - (max_num_tokens * chunk_idx))
|
||||
sp_tokens[i] - (max_num_tokens * chunk_idx))
|
||||
if local_size[i] <= 0:
|
||||
local_size[i] = 1 # ensure lockstep even if done
|
||||
return local_size
|
||||
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
@dataclass
|
||||
class DPMetadata:
|
||||
max_tokens_across_dp_cpu: torch.Tensor
|
||||
cu_tokens_across_dp_cpu: torch.Tensor
|
||||
num_tokens_across_dp_cpu: torch.Tensor
|
||||
|
||||
# NOTE: local_sizes should only be set by the chunked_sizes context manager
|
||||
local_sizes: Optional[list[int]] = None
|
||||
|
||||
@staticmethod
|
||||
@ -98,6 +113,17 @@ class DPMetadata:
|
||||
dist.all_reduce(num_tokens_tensor, group=group)
|
||||
return num_tokens_tensor.cpu()
|
||||
|
||||
# Get the cumulative tokens across sequence parallel ranks.
|
||||
# In this case the input to the MoEs will be distributed w.r.t both
|
||||
# DP and TP rank.
|
||||
# When sp_size==1, this is just the cummulative num tokens across DP.
|
||||
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
|
||||
num_tokens_across_sp_cpu = (
|
||||
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
|
||||
num_tokens_across_sp_cpu = (
|
||||
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
|
||||
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def should_ubatch_across_dp(
|
||||
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
|
||||
@ -147,10 +173,10 @@ class DPMetadata:
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
|
||||
) -> "DPMetadata":
|
||||
|
||||
assert parallel_config.data_parallel_size > 1
|
||||
@ -167,18 +193,18 @@ class DPMetadata:
|
||||
|
||||
# If num_tokens_across_dp is None, it will be computed by all_reduce
|
||||
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
|
||||
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
|
||||
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp is None:
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
assert (num_tokens_across_dp_cpu is None
|
||||
or num_tokens_across_dp_cpu[dp_rank] == batchsize
|
||||
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp_cpu is None:
|
||||
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
|
||||
batchsize, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
|
||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
|
||||
num_tokens_across_dp)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
|
||||
|
||||
@contextmanager
|
||||
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
def chunked_sizes(self, sequence_parallel_size: int,
|
||||
max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
"""
|
||||
Context manager to compute and temporarily set the per-rank local token
|
||||
sizes for a specific chunk during chunked forward execution.
|
||||
@ -192,31 +218,40 @@ class DPMetadata:
|
||||
`chunk_idx`, this context manager sets `self.local_sizes` to the number
|
||||
of tokens to process in that chunk on each rank.
|
||||
|
||||
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
|
||||
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
|
||||
to determine the chunk-wise split.
|
||||
|
||||
`self.local_sizes` is only valid inside the context.
|
||||
|
||||
Args:
|
||||
sequence_parallel_size: When Attn is TP and MoE layers are EP,
|
||||
we use SP between the layers to avoid
|
||||
redundant ops. We need this value to
|
||||
compute the chunked sizes.
|
||||
max_chunk_size_per_rank: The max number of tokens each rank is
|
||||
allowed to process in this chunk.
|
||||
chunk_idx: The index of the chunk to compute sizes for.
|
||||
"""
|
||||
cu_sizes = self.cu_tokens_across_dp_cpu
|
||||
num_tokens_across_dp_cpu = [
|
||||
(cu_sizes[i] -
|
||||
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
|
||||
for i in range(len(cu_sizes))
|
||||
]
|
||||
self.local_sizes = _compute_chunked_local_num_tokens(
|
||||
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size,
|
||||
max_chunk_size_per_rank, chunk_idx)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
@contextmanager
|
||||
def sp_local_sizes(self, sequence_parallel_size: int):
|
||||
"""
|
||||
Context mamager for setting self.local_sizes. Same as self.chunked_sizes
|
||||
but without any chunking.
|
||||
"""
|
||||
self.local_sizes = _compute_sp_num_tokens(
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
|
||||
assert self.local_sizes is not None
|
||||
return self.local_sizes
|
||||
|
||||
|
||||
|
||||
561
vllm/model_executor/layers/batch_invariant.py
Normal file
561
vllm/model_executor/layers/batch_invariant.py
Normal file
@ -0,0 +1,561 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
|
||||
args: dict[str, Any]) -> dict[str, Any]:
|
||||
ret = {}
|
||||
m, n, k = args["M"], args["N"], args["K"]
|
||||
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||
if "tiles_per_update" in args:
|
||||
ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||
f"tiles_per_update={args['tiles_per_update']:02}]")
|
||||
if "c_ptr" in args:
|
||||
bytes_per_elem = args["c_ptr"].element_size()
|
||||
else:
|
||||
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
||||
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
||||
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
@triton.jit(launch_metadata=_matmul_launch_metadata)
|
||||
def matmul_kernel_persistent(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr, #
|
||||
bias_ptr,
|
||||
M,
|
||||
N,
|
||||
K, #
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, #
|
||||
BLOCK_SIZE_N: tl.constexpr, #
|
||||
BLOCK_SIZE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
NUM_SMS: tl.constexpr, #
|
||||
A_LARGE: tl.constexpr,
|
||||
B_LARGE: tl.constexpr,
|
||||
C_LARGE: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
):
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
|
||||
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
||||
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
|
||||
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
|
||||
GROUP_SIZE_M, NUM_SMS)
|
||||
start_m = pid_m * BLOCK_SIZE_M
|
||||
start_n = pid_n * BLOCK_SIZE_N
|
||||
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
|
||||
if A_LARGE:
|
||||
offs_am = offs_am.to(tl.int64)
|
||||
if B_LARGE:
|
||||
offs_bn = offs_bn.to(tl.int64)
|
||||
offs_am = tl.where(offs_am < M, offs_am, 0)
|
||||
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
|
||||
BLOCK_SIZE_M)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
|
||||
BLOCK_SIZE_N)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for ki in range(k_tiles):
|
||||
if A_LARGE or B_LARGE:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
|
||||
tl.int64)
|
||||
else:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_bn[None, :] * stride_bn)
|
||||
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
|
||||
GROUP_SIZE_M, NUM_SMS)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if C_LARGE:
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
if HAS_BIAS:
|
||||
bias_ptrs = bias_ptr + offs_cn
|
||||
bias = tl.load(bias_ptrs, mask=offs_cn < N,
|
||||
other=0.0).to(tl.float32)
|
||||
accumulator += bias
|
||||
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||
c = accumulator.to(tl.float8e4nv)
|
||||
else:
|
||||
c = accumulator.to(tl.float16)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def matmul_persistent(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
bias: Union[torch.Tensor, None] = None):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.dtype == b.dtype, "Incompatible dtypes"
|
||||
assert bias is None or bias.dim() == 1, (
|
||||
"Currently assuming bias is 1D, let Horace know if you run into this")
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
dtype = a.dtype
|
||||
# Allocates output.
|
||||
c = torch.empty((M, N), device=a.device, dtype=dtype)
|
||||
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
def grid(META):
|
||||
return (min(
|
||||
NUM_SMS,
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"])), )
|
||||
|
||||
configs = {
|
||||
torch.bfloat16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
torch.float16: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
torch.float32: {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_stages": 3,
|
||||
"num_warps": 8,
|
||||
},
|
||||
}
|
||||
# print(a.device, b.device, c.device)
|
||||
matmul_kernel_persistent[grid](
|
||||
a,
|
||||
b,
|
||||
c, #
|
||||
bias,
|
||||
M,
|
||||
N,
|
||||
K, #
|
||||
a.stride(0),
|
||||
a.stride(1), #
|
||||
b.stride(0),
|
||||
b.stride(1), #
|
||||
c.stride(0),
|
||||
c.stride(1), #
|
||||
NUM_SMS=NUM_SMS, #
|
||||
A_LARGE=a.numel() > 2**31,
|
||||
B_LARGE=b.numel() > 2**31,
|
||||
C_LARGE=c.numel() > 2**31,
|
||||
HAS_BIAS=bias is not None,
|
||||
**configs[dtype],
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _log_softmax_kernel(
|
||||
input_ptr,
|
||||
output_ptr,
|
||||
input_row_stride,
|
||||
output_row_stride,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute log_softmax along the last dimension of a 2D tensor.
|
||||
Each block handles one row of the input tensor.
|
||||
"""
|
||||
# Get the row index for this block
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
|
||||
# Compute base pointers for input and output rows
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
|
||||
# Step 1: Find maximum value in the row for numerical stability
|
||||
max_val = -float("inf")
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
|
||||
|
||||
# Update maximum
|
||||
max_val = tl.max(tl.maximum(vals, max_val))
|
||||
|
||||
# Step 2: Compute sum of exp(x - max_val)
|
||||
sum_exp = 0.0
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||
|
||||
# Compute exp(x - max_val) and accumulate
|
||||
exp_vals = tl.exp(vals - max_val)
|
||||
sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
|
||||
|
||||
# Compute log(sum_exp)
|
||||
log_sum_exp = tl.log(sum_exp)
|
||||
|
||||
# Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
|
||||
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_idx < n_cols
|
||||
|
||||
# Load values
|
||||
vals = tl.load(row_start_ptr + col_idx, mask=mask)
|
||||
|
||||
# Compute log_softmax
|
||||
output = vals - max_val - log_sum_exp
|
||||
|
||||
# Store results
|
||||
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
||||
|
||||
|
||||
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
"""
|
||||
Compute log_softmax using Triton kernel.
|
||||
|
||||
Args:
|
||||
input: Input tensor
|
||||
dim: Dimension along which to compute log_softmax
|
||||
(only -1 or last dim supported)
|
||||
>> Stashed changes
|
||||
Returns:
|
||||
Tensor with log_softmax applied along the specified dimension
|
||||
"""
|
||||
if dim != -1 and dim != input.ndim - 1:
|
||||
raise ValueError("This implementation only supports log_softmax along "
|
||||
"the last dimension")
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input.shape
|
||||
input_2d = input.reshape(-1, input.shape[-1])
|
||||
input_2d = input_2d.contiguous()
|
||||
|
||||
n_rows, n_cols = input_2d.shape
|
||||
|
||||
# Allocate output tensor
|
||||
output = torch.empty_like(input_2d)
|
||||
|
||||
# Choose block size based on the number of columns
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# Launch kernel with one block per row
|
||||
grid = (n_rows, )
|
||||
_log_softmax_kernel[grid](
|
||||
input_2d,
|
||||
output,
|
||||
input_2d.stride(0),
|
||||
output.stride(0),
|
||||
n_cols,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
# Reshape output back to original shape
|
||||
return output.reshape(original_shape)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def mean_kernel(
|
||||
input_ptr,
|
||||
output_ptr,
|
||||
input_stride0,
|
||||
input_stride1,
|
||||
input_stride2,
|
||||
output_stride0,
|
||||
output_stride1,
|
||||
M, # size before reduction dim
|
||||
N, # size of reduction dim
|
||||
K, # size after reduction dim
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Kernel for computing mean along a single dimension.
|
||||
Input is viewed as (M, N, K) where N is the dimension being reduced.
|
||||
"""
|
||||
# Program ID gives us which output element we're computing
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# Compute output indices
|
||||
m_idx = pid // K
|
||||
k_idx = pid % K
|
||||
|
||||
# Bounds check
|
||||
if m_idx >= M or k_idx >= K:
|
||||
return
|
||||
|
||||
# Accumulate sum across reduction dimension
|
||||
acc = 0.0
|
||||
for n_start in range(0, N, BLOCK_SIZE):
|
||||
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = n_offsets < N
|
||||
|
||||
# Calculate input indices
|
||||
input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \
|
||||
+ k_idx * input_stride2
|
||||
|
||||
# Load and accumulate
|
||||
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
||||
acc += tl.sum(vals)
|
||||
|
||||
# Compute mean and store
|
||||
mean_val = acc / N
|
||||
output_idx = m_idx * output_stride0 + k_idx * output_stride1
|
||||
tl.store(output_ptr + output_idx, mean_val)
|
||||
|
||||
|
||||
def mean_dim(input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
|
||||
Args:
|
||||
input: Input tensor
|
||||
dim: Single dimension along which to compute mean
|
||||
keepdim: Whether to keep the reduced dimension
|
||||
dtype: Output dtype. If None, uses input dtype
|
||||
(or float32 for integer inputs)
|
||||
|
||||
Returns:
|
||||
Tensor with mean values along specified dimension
|
||||
"""
|
||||
# Validate inputs
|
||||
assert input.is_cuda, "Input must be a CUDA tensor"
|
||||
assert -input.ndim <= dim < input.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")
|
||||
|
||||
# Handle negative dim
|
||||
if dim < 0:
|
||||
dim = dim + input.ndim
|
||||
|
||||
# Handle dtype
|
||||
if dtype is None:
|
||||
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = input.dtype
|
||||
|
||||
# Convert input to appropriate dtype if needed
|
||||
if input.dtype != dtype:
|
||||
input = input.to(dtype)
|
||||
|
||||
# Get input shape and strides
|
||||
shape = list(input.shape)
|
||||
|
||||
# Calculate dimensions for kernel
|
||||
M = 1
|
||||
for i in range(dim):
|
||||
M *= shape[i]
|
||||
|
||||
N = shape[dim]
|
||||
|
||||
K = 1
|
||||
for i in range(dim + 1, len(shape)):
|
||||
K *= shape[i]
|
||||
|
||||
# Reshape input to 3D view (M, N, K)
|
||||
input_3d = input.reshape(M, N, K)
|
||||
|
||||
# Create output shape
|
||||
if keepdim:
|
||||
output_shape = shape.copy()
|
||||
output_shape[dim] = 1
|
||||
else:
|
||||
output_shape = shape[:dim] + shape[dim + 1:]
|
||||
|
||||
# Create output tensor
|
||||
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
||||
|
||||
# Reshape output for kernel
|
||||
if keepdim:
|
||||
output_2d = output.reshape(M, 1, K).squeeze(1)
|
||||
else:
|
||||
output_2d = output.reshape(M, K)
|
||||
|
||||
# Launch kernel
|
||||
grid = (M * K, )
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
mean_kernel[grid](
|
||||
input_3d,
|
||||
output_2d,
|
||||
input_3d.stride(0),
|
||||
input_3d.stride(1),
|
||||
input_3d.stride(2),
|
||||
output_2d.stride(0),
|
||||
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def mm_batch_invariant(a, b):
|
||||
return matmul_persistent(a, b)
|
||||
|
||||
|
||||
def addmm_batch_invariant(bias, a, b):
|
||||
return matmul_persistent(a, b, bias=bias)
|
||||
|
||||
|
||||
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
assert not _half_to_float, "not implemented"
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
def mean_batch_invariant(input,
|
||||
dim,
|
||||
keepdim=False,
|
||||
dtype: Union[torch.dtype, None] = None):
|
||||
assert dtype is None or dtype == torch.float32, \
|
||||
f"unsupported dtype: {dtype}"
|
||||
|
||||
result = input.to(torch.float32)
|
||||
|
||||
# Sort dimensions to reduce from largest to smallest to handle shifting dims
|
||||
# during iterative reduction.
|
||||
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
|
||||
|
||||
# Iteratively apply a deterministic mean.
|
||||
for d in sorted_dims:
|
||||
result = mean_dim(result, dim=d, keepdim=True)
|
||||
|
||||
if not keepdim:
|
||||
# Squeeze the reduced dimensions.
|
||||
for d in sorted_dims:
|
||||
result = result.squeeze(d)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_batch_invariant_MODE = False
|
||||
_batch_invariant_LIB = None
|
||||
|
||||
|
||||
def is_batch_invariant_mode_enabled():
|
||||
return _batch_invariant_MODE
|
||||
|
||||
|
||||
def enable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||
if _batch_invariant_MODE:
|
||||
return
|
||||
|
||||
_batch_invariant_MODE = True
|
||||
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::_log_softmax",
|
||||
_log_softmax_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||
|
||||
|
||||
def disable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||
if _batch_invariant_LIB is not None:
|
||||
_batch_invariant_LIB._destroy()
|
||||
_batch_invariant_MODE = False
|
||||
_batch_invariant_LIB = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_batch_invariant_mode(enabled: bool = True):
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||
old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
|
||||
if enabled:
|
||||
enable_batch_invariant_mode()
|
||||
else:
|
||||
disable_batch_invariant_mode()
|
||||
yield
|
||||
if _batch_invariant_LIB is not None:
|
||||
_batch_invariant_LIB._destroy()
|
||||
_batch_invariant_MODE, _batch_invariant_LIB = old_data
|
||||
|
||||
|
||||
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
|
||||
|
||||
|
||||
def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||
return AttentionBlockSize(block_m=16, block_n=16)
|
||||
|
||||
|
||||
def vllm_kernel_override_batch_invariant():
|
||||
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
val = os.getenv(env_key, "0")
|
||||
try:
|
||||
is_overridden = int(val) != 0
|
||||
except ValueError:
|
||||
is_overridden = False
|
||||
return is_overridden
|
||||
|
||||
|
||||
def init_batch_invariance():
|
||||
# this will hit all the csrc overrides as well
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
|
||||
enable_batch_invariant_mode()
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union, get_args, overload
|
||||
|
||||
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
|
||||
with ctx.dp_metadata.chunked_sizes(self.sp_size,
|
||||
moe_dp_chunk_size_per_rank,
|
||||
chunk_idx):
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = ctx.dp_metadata.sp_local_sizes(
|
||||
self.sp_size) if ctx.dp_metadata else nullcontext()
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
with sp_ctx:
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits, self.is_sequence_parallel)
|
||||
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states)
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states,
|
||||
self.is_sequence_parallel)
|
||||
|
||||
return states
|
||||
if (not self.is_sequence_parallel and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
states)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
||||
@ -639,6 +639,19 @@ def runai_safetensors_weights_iterator(
|
||||
yield from tensor_iter
|
||||
|
||||
|
||||
def _init_loader(
|
||||
pg: torch.distributed.ProcessGroup,
|
||||
device: torch.device,
|
||||
f_list: list[str],
|
||||
*,
|
||||
nogds: bool = False,
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg, device, nogds=nogds)
|
||||
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
return loader
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
@ -656,17 +669,31 @@ def fastsafetensors_weights_iterator(
|
||||
for i in range(0, len(hf_weights_files), pg.size())
|
||||
]
|
||||
|
||||
nogds = False
|
||||
|
||||
for f_list in tqdm(
|
||||
weight_files_sub_lists,
|
||||
desc="Loading safetensors using Fastsafetensor loader",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg, device)
|
||||
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
except RuntimeError as e:
|
||||
if "gds" not in str(e):
|
||||
raise
|
||||
|
||||
loader.close()
|
||||
nogds = True
|
||||
logger.warning_once(
|
||||
"GDS not enabled, setting `nogds=True`.\n"
|
||||
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
|
||||
)
|
||||
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||
fb = loader.copy_files_to_device()
|
||||
|
||||
try:
|
||||
keys = list(fb.key_to_rank_lidx.keys())
|
||||
for k in keys:
|
||||
|
||||
@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||
from transformers.models.aria.processing_aria import AriaProcessor
|
||||
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.config import QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -38,8 +38,7 @@ from .idefics2_vision_model import (
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
|
||||
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
is_pp_missing_parameter, maybe_prefix)
|
||||
|
||||
|
||||
class AriaImagePixelInputs(TensorSchema):
|
||||
@ -298,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
||||
Experts (MoE) Layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AriaTextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, cache_config, quant_config, prefix)
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config, prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.mlp = AriaTextMoELayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
@ -605,19 +602,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
multimodal_embeddings = self._process_image_input(image_input)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.image_token_index)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -628,10 +612,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is None:
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
|
||||
@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
class AyaVisionImagePixelInputs(TensorSchema):
|
||||
@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input, **kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
|
||||
@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
),
|
||||
})
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(weights)
|
||||
@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
|
||||
Pooler.for_encode(pooler_config),
|
||||
})
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(weights)
|
||||
|
||||
@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.new.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
||||
@ -27,7 +27,7 @@ from .blip import BlipVisionModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
|
||||
@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.model.vocabulary_mapping.image_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
image_token_id = self.model.vocabulary_mapping.image_token_id
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
|
||||
@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
class Cohere2VisionImagePixelInputs(TensorSchema):
|
||||
@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input, **kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_id,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
|
||||
@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
|
||||
self.norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
scale=logit_scale)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -32,7 +32,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("deepep_high_throughput",
|
||||
"deepep_low_latency")
|
||||
and parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# TODO: We can replace the all_reduce at the end of attn with a
|
||||
# reduce_scatter instead of chunking here.
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
||||
hidden_states)
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
# The image token id may be various
|
||||
_IMAGE_TOKEN = "<image>"
|
||||
@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
|
||||
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
|
||||
|
||||
self.vision = self._init_vision_module(self.vision_config,
|
||||
quant_config,
|
||||
@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.image_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
|
||||
@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_id,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
assert input_ids is not None
|
||||
inputs_embeds = self.get_multimodal_embeddings(
|
||||
input_ids,
|
||||
image_input=image_input,
|
||||
)
|
||||
input_ids = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids=input_ids,
|
||||
|
||||
@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
from .vision import get_vit_attn_backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is None:
|
||||
return inputs_embeds
|
||||
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[self.config.im_patch_id])
|
||||
return inputs_embeds
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -29,10 +29,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.mtp_emb_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
|
||||
prefix)
|
||||
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
ErnieMultiTokenPredictorLayer(
|
||||
config,
|
||||
vllm_config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
@ -116,6 +110,9 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -160,6 +157,9 @@ class ErnieMTP(nn.Module, SupportsPP):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
|
||||
@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
|
||||
kwargs = self.prepare_attn_masks(
|
||||
input_ids,
|
||||
|
||||
@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||
# them here, as the model forward has only access to the input_embeds.
|
||||
if input_ids is not None:
|
||||
@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
||||
per_layer_inputs)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
# NOTE: this order of processing mm items is important
|
||||
[self.config.image_token_id, self.config.audio_token_id])
|
||||
return inputs_embeds
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
|
||||
|
||||
class Glm4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[Glm4Config] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
||||
@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
multimodal_embeddings += video_embeddings
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if (multimodal_embeddings is not None
|
||||
and len(multimodal_embeddings) != 0
|
||||
and all(embed.numel() > 0 for embed in multimodal_embeddings)):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[self.config.image_token_id, self.config.video_token_id],
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def get_input_embeddings_v0(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
from .utils import flatten_bn, isin_list
|
||||
|
||||
|
||||
class GLMVImagePixelInputs(TensorSchema):
|
||||
@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.boi_token_id,
|
||||
self.config.pad_token_id,
|
||||
self.config.eoi_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, [
|
||||
self.config.boi_token_id,
|
||||
self.config.pad_token_id,
|
||||
self.config.eoi_token_id,
|
||||
]),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
|
||||
@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
vllm_config: VllmConfig,
|
||||
layer_idx: int,
|
||||
quant_config: QuantizationConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = config.num_local_experts
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
apply_router_weight_on_input=False,
|
||||
has_bias=True,
|
||||
activation="swigluoai")
|
||||
activation="swigluoai",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
if self.is_sequence_parallel:
|
||||
x = sequence_parallel_chunk(x)
|
||||
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
x = x[:num_tokens]
|
||||
return x
|
||||
|
||||
|
||||
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
cache_config: CacheConfig,
|
||||
quant_config: QuantizationConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.attn = OAIAttention(config,
|
||||
prefix=f"{prefix}.attn",
|
||||
cache_config=cache_config)
|
||||
self.mlp = MLPBlock(config,
|
||||
self.mlp = MLPBlock(vllm_config,
|
||||
self.layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.config.hidden_size = self.config.hidden_size
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.config.num_hidden_layers,
|
||||
lambda prefix: TransformerBlock(
|
||||
self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
vllm_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
|
||||
@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .blip2 import Blip2QFormerModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, embed_multimodal,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
|
||||
### Audio Input
|
||||
@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration(
|
||||
# Split variable length features into a tuple
|
||||
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: object,
|
||||
@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
if audio_input is None:
|
||||
return []
|
||||
return None
|
||||
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
return audio_features
|
||||
|
||||
@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the merged LLM / audio embeddings."""
|
||||
if multimodal_embeddings is None \
|
||||
or len(multimodal_embeddings) == 0:
|
||||
return self.language_model.get_input_embeddings(input_ids)
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = embed_multimodal(
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
self.config.audio_token_index,
|
||||
self.language_model.model.get_input_embeddings,
|
||||
multimodal_embeddings,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration(
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
audio_embeds = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
audio_embeds,
|
||||
is_multimodal=input_ids == self.config.audio_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
model_output = self.language_model(input_ids, positions,
|
||||
|
||||
@ -29,12 +29,13 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.granitemoe import GraniteMoeConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
is_sequence_parallel=False,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(hidden_size,
|
||||
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts")
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
num_tokens = orig_shape[0]
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteMoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
|
||||
prefix=f"{prefix}.block_sparse_moe")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GraniteMoeDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
|
||||
maybe_prefix)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
EOT = "<|endofturn|>"
|
||||
@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: Unpack[HCXVisionMultimodalInputs],
|
||||
) -> Optional[MultiModalEmbeddings]:
|
||||
) -> MultiModalEmbeddings:
|
||||
|
||||
multimodal_embeddings = list()
|
||||
if kwargs.get("pixel_values_images") is not None:
|
||||
@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
multimodal_embeddings.append(_multimodal_embeddings_videos)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.image_token_id,
|
||||
self.config.video_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings,
|
||||
is_multimodal=isin_list(
|
||||
input_ids,
|
||||
[self.config.image_token_id, self.config.video_token_id]),
|
||||
)
|
||||
input_ids = None
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
||||
@ -52,8 +52,7 @@ from .idefics2_vision_model import (
|
||||
# yapf: enable
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
||||
|
||||
|
||||
class Idefics3ImagePixelInputs(TensorSchema):
|
||||
@ -539,10 +538,7 @@ class Idefics3Model(nn.Module):
|
||||
|
||||
return image_hidden_states
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.text_model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model.text_model(input_ids,
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping, MutableSequence
|
||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
Union, overload, runtime_checkable)
|
||||
from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional,
|
||||
Protocol, Union, overload, runtime_checkable)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import supports_kw
|
||||
|
||||
from .interfaces_base import is_pooling_model
|
||||
from .interfaces_base import VllmModel, is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
def get_language_model(self) -> VllmModel:
|
||||
"""
|
||||
Returns the underlying language model used for text generation.
|
||||
|
||||
@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings,
|
||||
*,
|
||||
is_multimodal: torch.Tensor,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> Tensor:
|
||||
...
|
||||
|
||||
def _get_text_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
get_input_embeddings: Callable[[Tensor], Tensor],
|
||||
*,
|
||||
is_multimodal: Optional[Tensor],
|
||||
handle_oov_mm_token: bool,
|
||||
) -> Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
|
||||
return get_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Returns the input embeddings merged from the text embeddings from
|
||||
input_ids and the multimodal embeddings generated from multimodal
|
||||
kwargs.
|
||||
Apply token embeddings to `input_ids`.
|
||||
|
||||
If `multimodal_embeddings` is passed, scatter them into
|
||||
`input_ids` according to the mask `is_multimodal`.
|
||||
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=False`
|
||||
to avoid calling the language model's `get_input_embeddings` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
"""
|
||||
...
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
inputs_embeds = self._get_text_embeddings(
|
||||
input_ids,
|
||||
self.get_language_model().get_input_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
if is_multimodal is None:
|
||||
raise ValueError(
|
||||
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||
"please update your model runner according to "
|
||||
"https://github.com/vllm-project/vllm/pull/16229.")
|
||||
|
||||
return _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
||||
@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply token embeddings to `input_ids`."""
|
||||
...
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_get_input_embeddings(
|
||||
model: Union[type[object], object]) -> bool:
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if not callable(model_get_input_embeddings):
|
||||
logger.warning(
|
||||
"The model (%s) is missing the `get_input_embeddings` method.",
|
||||
model,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||
model_forward = getattr(model, "forward", None)
|
||||
if not callable(model_forward):
|
||||
@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
|
||||
def is_vllm_model(
|
||||
model: Union[type[object], object],
|
||||
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
||||
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
|
||||
return (_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
and _check_vllm_model_forward(model))
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
||||
@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, isin_list, maybe_prefix)
|
||||
|
||||
|
||||
class InternS1MultiModalProjector(nn.Module):
|
||||
@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
assert len(context_token_ids) >= 1
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
isin_list, maybe_prefix)
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
assert len(context_token_ids) >= 1
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module):
|
||||
multimodal_embeddings += video_embeddings
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[
|
||||
self.config.image_token_id,
|
||||
self.config.video_token_id,
|
||||
],
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def get_input_embeddings_v0(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user