mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 20:36:57 +08:00
Merge branch 'main' into fix_hang
This commit is contained in:
commit
562107efb1
@ -44,7 +44,6 @@ docker run \
|
|||||||
pytest -v -s v1/structured_output
|
pytest -v -s v1/structured_output
|
||||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
||||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
||||||
|
pytest -v -s v1/test_metrics
|
||||||
pytest -v -s v1/test_serial_utils.py
|
pytest -v -s v1/test_serial_utils.py
|
||||||
pytest -v -s v1/test_utils.py
|
|
||||||
pytest -v -s v1/test_metrics_reader.py
|
|
||||||
'
|
'
|
||||||
|
|||||||
@ -159,10 +159,7 @@ steps:
|
|||||||
- examples/offline_inference/rlhf.py
|
- examples/offline_inference/rlhf.py
|
||||||
- examples/offline_inference/rlhf_colocate.py
|
- examples/offline_inference/rlhf_colocate.py
|
||||||
- tests/examples/offline_inference/data_parallel.py
|
- tests/examples/offline_inference/data_parallel.py
|
||||||
- tests/v1/test_async_llm_dp.py
|
- tests/v1/distributed
|
||||||
- tests/v1/test_external_lb_dp.py
|
|
||||||
- tests/v1/test_internal_lb_dp.py
|
|
||||||
- tests/v1/test_hybrid_lb_dp.py
|
|
||||||
- tests/v1/engine/test_engine_core_client.py
|
- tests/v1/engine/test_engine_core_client.py
|
||||||
- tests/distributed/test_symm_mem_allreduce.py
|
- tests/distributed/test_symm_mem_allreduce.py
|
||||||
commands:
|
commands:
|
||||||
@ -180,10 +177,10 @@ steps:
|
|||||||
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
# test with internal dp
|
# test with internal dp
|
||||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
|
||||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||||
- pytest -v -s distributed/test_utils.py
|
- pytest -v -s distributed/test_utils.py
|
||||||
- pytest -v -s compile/test_basic_correctness.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
@ -300,12 +297,9 @@ steps:
|
|||||||
- pytest -v -s v1/spec_decode
|
- pytest -v -s v1/spec_decode
|
||||||
- pytest -v -s v1/kv_connector/unit
|
- pytest -v -s v1/kv_connector/unit
|
||||||
- pytest -v -s v1/metrics
|
- pytest -v -s v1/metrics
|
||||||
- pytest -v -s v1/test_kv_sharing.py
|
|
||||||
- pytest -v -s v1/test_metrics_reader.py
|
|
||||||
- pytest -v -s v1/test_oracle.py
|
- pytest -v -s v1/test_oracle.py
|
||||||
- pytest -v -s v1/test_request.py
|
- pytest -v -s v1/test_request.py
|
||||||
- pytest -v -s v1/test_serial_utils.py
|
- pytest -v -s v1/test_serial_utils.py
|
||||||
- pytest -v -s v1/test_utils.py
|
|
||||||
# Integration test for streaming correctness (requires special branch).
|
# Integration test for streaming correctness (requires special branch).
|
||||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||||
@ -465,29 +459,18 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/mamba
|
- pytest -v -s kernels/mamba
|
||||||
|
|
||||||
- label: Tensorizer Test # 14min
|
- label: Model Executor Test # 23min
|
||||||
timeout_in_minutes: 25
|
timeout_in_minutes: 35
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
|
||||||
- vllm/model_executor/model_loader
|
|
||||||
- tests/tensorizer_loader
|
|
||||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
commands:
|
|
||||||
- apt-get update && apt-get install -y curl libsodium23
|
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
- pytest -v -s tensorizer_loader
|
|
||||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
|
|
||||||
- label: Model Executor Test # 7min
|
|
||||||
timeout_in_minutes: 20
|
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor
|
- vllm/model_executor
|
||||||
- tests/model_executor
|
- tests/model_executor
|
||||||
|
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
commands:
|
commands:
|
||||||
- apt-get update && apt-get install -y curl libsodium23
|
- apt-get update && apt-get install -y curl libsodium23
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -v -s model_executor
|
- pytest -v -s model_executor
|
||||||
|
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
|
|
||||||
- label: Benchmarks # 11min
|
- label: Benchmarks # 11min
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 20
|
||||||
@ -906,14 +889,13 @@ steps:
|
|||||||
- tests/compile/test_wrapper.py
|
- tests/compile/test_wrapper.py
|
||||||
- tests/distributed/
|
- tests/distributed/
|
||||||
- tests/entrypoints/llm/test_collective_rpc.py
|
- tests/entrypoints/llm/test_collective_rpc.py
|
||||||
- tests/v1/test_async_llm_dp.py
|
- tests/v1/distributed
|
||||||
- tests/v1/test_external_lb_dp.py
|
|
||||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- tests/v1/shutdown
|
- tests/v1/shutdown
|
||||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||||
commands:
|
commands:
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||||
- pytest -v -s ./compile/test_basic_correctness.py
|
- pytest -v -s ./compile/test_basic_correctness.py
|
||||||
|
|||||||
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -12,8 +12,6 @@
|
|||||||
/vllm/model_executor/layers/mamba @tdoublep
|
/vllm/model_executor/layers/mamba @tdoublep
|
||||||
/vllm/model_executor/model_loader @22quinn
|
/vllm/model_executor/model_loader @22quinn
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||||
/vllm/v1/attention @LucasWilkinson
|
|
||||||
/vllm/v1/sample @22quinn @houseroad
|
|
||||||
/vllm/vllm_flash_attn @LucasWilkinson
|
/vllm/vllm_flash_attn @LucasWilkinson
|
||||||
/vllm/lora @jeejeelee
|
/vllm/lora @jeejeelee
|
||||||
/vllm/reasoning @aarnphm @chaunceyjiang
|
/vllm/reasoning @aarnphm @chaunceyjiang
|
||||||
@ -28,11 +26,13 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
|
|
||||||
# vLLM V1
|
# vLLM V1
|
||||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
/vllm/v1/attention @LucasWilkinson
|
||||||
/vllm/v1/spec_decode @benchislett @luccafong
|
|
||||||
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
||||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||||
|
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||||
|
/vllm/v1/spec_decode @benchislett @luccafong
|
||||||
|
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||||
/vllm/v1/offloading @ApostaC
|
/vllm/v1/offloading @ApostaC
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||||
/tests/lora @jeejeelee
|
/tests/lora @jeejeelee
|
||||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||||
/tests/v1/kv_connector @ApostaC
|
/tests/v1/kv_connector @ApostaC
|
||||||
/tests/v1/offloading @ApostaC
|
/tests/v1/offloading @ApostaC
|
||||||
|
|
||||||
|
|||||||
2
.github/mergify.yml
vendored
2
.github/mergify.yml
vendored
@ -274,7 +274,7 @@ pull_request_rules:
|
|||||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
- files~=^tests/tensorizer_loader/
|
- files~=^tests/model_executor/model_loader/tensorizer_loader/
|
||||||
actions:
|
actions:
|
||||||
assign:
|
assign:
|
||||||
users:
|
users:
|
||||||
|
|||||||
16
csrc/core/batch_invariant.hpp
Normal file
16
csrc/core/batch_invariant.hpp
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <string>
|
||||||
|
#include <cctype>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// vllm_kernel_override_batch_invariant(); returns true
|
||||||
|
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
|
||||||
|
inline bool vllm_kernel_override_batch_invariant() {
|
||||||
|
std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
|
||||||
|
const char* val = std::getenv(env_key.c_str());
|
||||||
|
return (val && std::atoi(val) != 0) ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@ -1,6 +1,7 @@
|
|||||||
#include "type_convert.cuh"
|
#include "type_convert.cuh"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
|
#include "core/batch_invariant.hpp"
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
|||||||
wt_ptr % req_alignment_bytes == 0;
|
wt_ptr % req_alignment_bytes == 0;
|
||||||
bool offsets_are_multiple_of_vector_width =
|
bool offsets_are_multiple_of_vector_width =
|
||||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
|
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||||
|
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
|
||||||
|
!batch_invariant_launch) {
|
||||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||||
} else {
|
} else {
|
||||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||||
@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
|||||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||||
|
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||||
LAUNCH_FUSED_POLY_NORM(8);
|
LAUNCH_FUSED_POLY_NORM(8);
|
||||||
} else {
|
} else {
|
||||||
LAUNCH_FUSED_POLY_NORM(0);
|
LAUNCH_FUSED_POLY_NORM(0);
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
#include "quantization/fp8/common.cuh"
|
#include "quantization/fp8/common.cuh"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "cub_helpers.h"
|
#include "cub_helpers.h"
|
||||||
|
#include "core/batch_invariant.hpp"
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
|
|||||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||||
bool ptrs_are_aligned =
|
bool ptrs_are_aligned =
|
||||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
|
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||||
|
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
|
||||||
|
!batch_invariant_launch) {
|
||||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||||
} else {
|
} else {
|
||||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||||
|
|||||||
@ -21,6 +21,7 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include "../cuda_compat.h"
|
#include "../cuda_compat.h"
|
||||||
#include "../cub_helpers.h"
|
#include "../cub_helpers.h"
|
||||||
|
#include "../core/batch_invariant.hpp"
|
||||||
|
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
|||||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||||
static constexpr int VPT = Constants::VPT;
|
static constexpr int VPT = Constants::VPT;
|
||||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||||
|
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||||
|
|
||||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||||
|
|||||||
@ -391,18 +391,28 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
|||||||
git clone --depth 1 --recursive --shallow-submodules \
|
git clone --depth 1 --recursive --shallow-submodules \
|
||||||
--branch ${FLASHINFER_GIT_REF} \
|
--branch ${FLASHINFER_GIT_REF} \
|
||||||
${FLASHINFER_GIT_REPO} flashinfer
|
${FLASHINFER_GIT_REPO} flashinfer
|
||||||
|
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||||
|
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||||
|
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||||
|
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||||
|
else
|
||||||
|
# CUDA 12.8+ supports 10.0a and 12.0
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||||
|
fi
|
||||||
pushd flashinfer
|
pushd flashinfer
|
||||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
|
||||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
echo "🏗️ Installing FlashInfer from pre-compiled wheel"
|
||||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
|
||||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
# Download pre-compiled cubins
|
||||||
else
|
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||||
# CUDA 12.8+ supports 10.0a and 12.0
|
python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
|
||||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
|
||||||
fi
|
fi
|
||||||
|
elif [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||||
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||||
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
|
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
|
||||||
@ -536,7 +546,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
else \
|
else \
|
||||||
BITSANDBYTES_VERSION="0.46.1"; \
|
BITSANDBYTES_VERSION="0.46.1"; \
|
||||||
fi; \
|
fi; \
|
||||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3]
|
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3]>=0.14.0'
|
||||||
|
|
||||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||||
|
|
||||||
|
|||||||
@ -66,35 +66,12 @@ Further update the model as follows:
|
|||||||
!!! important
|
!!! important
|
||||||
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
|
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
|
||||||
|
|
||||||
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
|
!!! note
|
||||||
|
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
|
||||||
|
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
|
||||||
|
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
|
||||||
|
|
||||||
??? code
|
You may override this method if additional logic is required for your model when merging embeddings.
|
||||||
|
|
||||||
```python
|
|
||||||
from .utils import merge_multimodal_embeddings
|
|
||||||
|
|
||||||
class YourModelForImage2Seq(nn.Module):
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
# `get_input_embeddings` should already be implemented for the language
|
|
||||||
# model as one of the requirements of basic vLLM model implementation.
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
if multimodal_embeddings is not None:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
multimodal_embeddings=multimodal_embeddings,
|
|
||||||
placeholder_token_id=self.config.image_token_index)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
```
|
|
||||||
|
|
||||||
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
|
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
|||||||
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
|
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
|
||||||
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
|
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||||
|
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||||
|
|
||||||
## Offline Inference
|
## Offline Inference
|
||||||
|
|
||||||
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
|
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
|
||||||
|
|||||||
@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
|
|||||||
- Implement proper authentication and authorization for management interfaces
|
- Implement proper authentication and authorization for management interfaces
|
||||||
- Follow the principle of least privilege for all system components
|
- Follow the principle of least privilege for all system components
|
||||||
|
|
||||||
|
### 4. **Restrict Domains Access for Media URLs:**
|
||||||
|
|
||||||
|
Restrict domains that vLLM can access for media URLs by setting
|
||||||
|
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||||
|
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||||
|
|
||||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||||
|
|
||||||
While vLLM is designed to allow unsafe network services to be isolated to
|
While vLLM is designed to allow unsafe network services to be isolated to
|
||||||
|
|||||||
@ -38,11 +38,13 @@ client = OpenAI(
|
|||||||
base_url=openai_api_base,
|
base_url=openai_api_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = {"User-Agent": "vLLM Example Client"}
|
||||||
|
|
||||||
|
|
||||||
def encode_base64_content_from_url(content_url: str) -> str:
|
def encode_base64_content_from_url(content_url: str) -> str:
|
||||||
"""Encode a content retrieved from a remote url to base64 format."""
|
"""Encode a content retrieved from a remote url to base64 format."""
|
||||||
|
|
||||||
with requests.get(content_url) as response:
|
with requests.get(content_url, headers=headers) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = base64.b64encode(response.content).decode("utf-8")
|
result = base64.b64encode(response.content).decode("utf-8")
|
||||||
|
|
||||||
@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
# Text-only inference
|
# Text-only inference
|
||||||
def run_text_only(model: str) -> None:
|
def run_text_only(model: str, max_completion_tokens: int) -> None:
|
||||||
chat_completion = client.chat.completions.create(
|
chat_completion = client.chat.completions.create(
|
||||||
messages=[{"role": "user", "content": "What's the capital of France?"}],
|
messages=[{"role": "user", "content": "What's the capital of France?"}],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion.choices[0].message.content
|
result = chat_completion.choices[0].message.content
|
||||||
print("Chat completion output:", result)
|
print("Chat completion output:\n", result)
|
||||||
|
|
||||||
|
|
||||||
# Single-image input inference
|
# Single-image input inference
|
||||||
def run_single_image(model: str) -> None:
|
def run_single_image(model: str, max_completion_tokens: int) -> None:
|
||||||
## Use image url in the payload
|
## Use image url in the payload
|
||||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
chat_completion_from_url = client.chat.completions.create(
|
chat_completion_from_url = client.chat.completions.create(
|
||||||
@ -79,11 +81,11 @@ def run_single_image(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_url.choices[0].message.content
|
result = chat_completion_from_url.choices[0].message.content
|
||||||
print("Chat completion output from image url:", result)
|
print("Chat completion output from image url:\n", result)
|
||||||
|
|
||||||
## Use base64 encoded image in the payload
|
## Use base64 encoded image in the payload
|
||||||
image_base64 = encode_base64_content_from_url(image_url)
|
image_base64 = encode_base64_content_from_url(image_url)
|
||||||
@ -101,7 +103,7 @@ def run_single_image(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_base64.choices[0].message.content
|
result = chat_completion_from_base64.choices[0].message.content
|
||||||
@ -109,7 +111,7 @@ def run_single_image(model: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
# Multi-image input inference
|
# Multi-image input inference
|
||||||
def run_multi_image(model: str) -> None:
|
def run_multi_image(model: str, max_completion_tokens: int) -> None:
|
||||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||||
chat_completion_from_url = client.chat.completions.create(
|
chat_completion_from_url = client.chat.completions.create(
|
||||||
@ -130,15 +132,15 @@ def run_multi_image(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_url.choices[0].message.content
|
result = chat_completion_from_url.choices[0].message.content
|
||||||
print("Chat completion output:", result)
|
print("Chat completion output:\n", result)
|
||||||
|
|
||||||
|
|
||||||
# Video input inference
|
# Video input inference
|
||||||
def run_video(model: str) -> None:
|
def run_video(model: str, max_completion_tokens: int) -> None:
|
||||||
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"
|
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"
|
||||||
video_base64 = encode_base64_content_from_url(video_url)
|
video_base64 = encode_base64_content_from_url(video_url)
|
||||||
|
|
||||||
@ -157,11 +159,11 @@ def run_video(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_url.choices[0].message.content
|
result = chat_completion_from_url.choices[0].message.content
|
||||||
print("Chat completion output from image url:", result)
|
print("Chat completion output from video url:\n", result)
|
||||||
|
|
||||||
## Use base64 encoded video in the payload
|
## Use base64 encoded video in the payload
|
||||||
chat_completion_from_base64 = client.chat.completions.create(
|
chat_completion_from_base64 = client.chat.completions.create(
|
||||||
@ -178,15 +180,15 @@ def run_video(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_base64.choices[0].message.content
|
result = chat_completion_from_base64.choices[0].message.content
|
||||||
print("Chat completion output from base64 encoded image:", result)
|
print("Chat completion output from base64 encoded video:\n", result)
|
||||||
|
|
||||||
|
|
||||||
# Audio input inference
|
# Audio input inference
|
||||||
def run_audio(model: str) -> None:
|
def run_audio(model: str, max_completion_tokens: int) -> None:
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
audio_url = AudioAsset("winning_call").url
|
audio_url = AudioAsset("winning_call").url
|
||||||
@ -211,11 +213,11 @@ def run_audio(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_base64.choices[0].message.content
|
result = chat_completion_from_base64.choices[0].message.content
|
||||||
print("Chat completion output from input audio:", result)
|
print("Chat completion output from input audio:\n", result)
|
||||||
|
|
||||||
# HTTP URL
|
# HTTP URL
|
||||||
chat_completion_from_url = client.chat.completions.create(
|
chat_completion_from_url = client.chat.completions.create(
|
||||||
@ -235,11 +237,11 @@ def run_audio(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_url.choices[0].message.content
|
result = chat_completion_from_url.choices[0].message.content
|
||||||
print("Chat completion output from audio url:", result)
|
print("Chat completion output from audio url:\n", result)
|
||||||
|
|
||||||
# base64 URL
|
# base64 URL
|
||||||
chat_completion_from_base64 = client.chat.completions.create(
|
chat_completion_from_base64 = client.chat.completions.create(
|
||||||
@ -259,14 +261,14 @@ def run_audio(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_base64.choices[0].message.content
|
result = chat_completion_from_base64.choices[0].message.content
|
||||||
print("Chat completion output from base64 encoded audio:", result)
|
print("Chat completion output from base64 encoded audio:\n", result)
|
||||||
|
|
||||||
|
|
||||||
def run_multi_audio(model: str) -> None:
|
def run_multi_audio(model: str, max_completion_tokens: int) -> None:
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
# Two different audios to showcase batched inference.
|
# Two different audios to showcase batched inference.
|
||||||
@ -300,11 +302,11 @@ def run_multi_audio(model: str) -> None:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
model=model,
|
model=model,
|
||||||
max_completion_tokens=64,
|
max_completion_tokens=max_completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = chat_completion_from_base64.choices[0].message.content
|
result = chat_completion_from_base64.choices[0].message.content
|
||||||
print("Chat completion output from input audio:", result)
|
print("Chat completion output from input audio:\n", result)
|
||||||
|
|
||||||
|
|
||||||
example_function_map = {
|
example_function_map = {
|
||||||
@ -330,13 +332,20 @@ def parse_args():
|
|||||||
choices=list(example_function_map.keys()),
|
choices=list(example_function_map.keys()),
|
||||||
help="Conversation type with multimodal data.",
|
help="Conversation type with multimodal data.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-completion-tokens",
|
||||||
|
"-n",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="Maximum number of tokens to generate for each completion.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main(args) -> None:
|
def main(args) -> None:
|
||||||
chat_type = args.chat_type
|
chat_type = args.chat_type
|
||||||
model = get_first_model(client)
|
model = get_first_model(client)
|
||||||
example_function_map[chat_type](model)
|
example_function_map[chat_type](model, args.max_completion_tokens)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -43,7 +43,6 @@ tritonclient==2.51.0
|
|||||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||||
numba == 0.61.2; python_version > '3.9'
|
numba == 0.61.2; python_version > '3.9'
|
||||||
numpy
|
numpy
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer[s3]==0.14.0
|
||||||
runai-model-streamer-s3==0.11.0
|
|
||||||
fastsafetensors>=0.1.10
|
fastsafetensors>=0.1.10
|
||||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||||
|
|||||||
@ -5,8 +5,6 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req
|
|||||||
numba == 0.61.2; python_version > '3.9'
|
numba == 0.61.2; python_version > '3.9'
|
||||||
|
|
||||||
# Dependencies for AMD GPUs
|
# Dependencies for AMD GPUs
|
||||||
boto3
|
|
||||||
botocore
|
|
||||||
datasets
|
datasets
|
||||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||||
peft
|
peft
|
||||||
@ -15,7 +13,6 @@ tensorizer==2.10.1
|
|||||||
packaging>=24.2
|
packaging>=24.2
|
||||||
setuptools>=77.0.3,<80.0.0
|
setuptools>=77.0.3,<80.0.0
|
||||||
setuptools-scm>=8
|
setuptools-scm>=8
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer[s3]==0.14.0
|
||||||
runai-model-streamer-s3==0.11.0
|
|
||||||
conch-triton-kernels==1.2.1
|
conch-triton-kernels==1.2.1
|
||||||
timm>=1.0.17
|
timm>=1.0.17
|
||||||
@ -51,8 +51,7 @@ tritonclient==2.51.0
|
|||||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||||
numba == 0.61.2; python_version > '3.9'
|
numba == 0.61.2; python_version > '3.9'
|
||||||
numpy
|
numpy
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer[s3]==0.14.0
|
||||||
runai-model-streamer-s3==0.11.0
|
|
||||||
fastsafetensors>=0.1.10
|
fastsafetensors>=0.1.10
|
||||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||||
decord==0.6.0
|
decord==0.6.0
|
||||||
|
|||||||
@ -72,7 +72,9 @@ blobfile==3.0.0
|
|||||||
bm25s==0.2.13
|
bm25s==0.2.13
|
||||||
# via mteb
|
# via mteb
|
||||||
boto3==1.35.57
|
boto3==1.35.57
|
||||||
# via tensorizer
|
# via
|
||||||
|
# runai-model-streamer-s3
|
||||||
|
# tensorizer
|
||||||
botocore==1.35.57
|
botocore==1.35.57
|
||||||
# via
|
# via
|
||||||
# boto3
|
# boto3
|
||||||
@ -925,10 +927,10 @@ rsa==4.9.1
|
|||||||
# via google-auth
|
# via google-auth
|
||||||
rtree==1.4.0
|
rtree==1.4.0
|
||||||
# via torchgeo
|
# via torchgeo
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer==0.14.0
|
||||||
# via -r requirements/test.in
|
|
||||||
runai-model-streamer-s3==0.11.0
|
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
|
runai-model-streamer-s3==0.14.0
|
||||||
|
# via runai-model-streamer
|
||||||
s3transfer==0.10.3
|
s3transfer==0.10.3
|
||||||
# via boto3
|
# via boto3
|
||||||
sacrebleu==2.4.3
|
sacrebleu==2.4.3
|
||||||
|
|||||||
5
setup.py
5
setup.py
@ -654,10 +654,7 @@ setup(
|
|||||||
"bench": ["pandas", "datasets"],
|
"bench": ["pandas", "datasets"],
|
||||||
"tensorizer": ["tensorizer==2.10.1"],
|
"tensorizer": ["tensorizer==2.10.1"],
|
||||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||||
"runai": [
|
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||||
"runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
|
|
||||||
"google-cloud-storage", "runai-model-streamer-s3", "boto3"
|
|
||||||
],
|
|
||||||
"audio": ["librosa", "soundfile",
|
"audio": ["librosa", "soundfile",
|
||||||
"mistral_common[audio]"], # Required for audio processing
|
"mistral_common[audio]"], # Required for audio processing
|
||||||
"video": [], # Kept for backwards compatibility
|
"video": [], # Kept for backwards compatibility
|
||||||
|
|||||||
@ -45,6 +45,7 @@ class MockModelConfig:
|
|||||||
logits_processor_pattern: Optional[str] = None
|
logits_processor_pattern: Optional[str] = None
|
||||||
diff_sampling_param: Optional[dict] = None
|
diff_sampling_param: Optional[dict] = None
|
||||||
allowed_local_media_path: str = ""
|
allowed_local_media_path: str = ""
|
||||||
|
allowed_media_domains: Optional[list[str]] = None
|
||||||
encoder_config = None
|
encoder_config = None
|
||||||
generation_config: str = "auto"
|
generation_config: str = "auto"
|
||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
|
|||||||
@ -240,6 +240,7 @@ class MockModelConfig:
|
|||||||
logits_processor_pattern = None
|
logits_processor_pattern = None
|
||||||
diff_sampling_param: Optional[dict] = None
|
diff_sampling_param: Optional[dict] = None
|
||||||
allowed_local_media_path: str = ""
|
allowed_local_media_path: str = ""
|
||||||
|
allowed_media_domains: Optional[list[str]] = None
|
||||||
encoder_config = None
|
encoder_config = None
|
||||||
generation_config: str = "auto"
|
generation_config: str = "auto"
|
||||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
|
|||||||
parse_chat_messages,
|
parse_chat_messages,
|
||||||
parse_chat_messages_futures,
|
parse_chat_messages_futures,
|
||||||
resolve_chat_template_content_format,
|
resolve_chat_template_content_format,
|
||||||
|
resolve_chat_template_kwargs,
|
||||||
resolve_hf_chat_template)
|
resolve_hf_chat_template)
|
||||||
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
||||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||||
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
|
|||||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
||||||
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
||||||
|
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
|
||||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
||||||
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||||
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
|||||||
assert isinstance(chat_template, str)
|
assert isinstance(chat_template, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, expected_kwargs",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
QWEN2VL_MODEL_ID,
|
||||||
|
{
|
||||||
|
"add_vision_id", "add_generation_prompt",
|
||||||
|
"continue_final_message", "tools"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
QWEN3_MODEL_ID,
|
||||||
|
{
|
||||||
|
"enable_thinking", "add_generation_prompt",
|
||||||
|
"continue_final_message", "tools"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
|
||||||
|
expected_kwargs):
|
||||||
|
"""checks that chat_template is a dict type for HF models."""
|
||||||
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||||
|
model_info.check_available_online(on_fail="skip")
|
||||||
|
|
||||||
|
tools = ([{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "dummy_function_name",
|
||||||
|
"description": "This is a dummy function",
|
||||||
|
"parameters": sample_json_schema,
|
||||||
|
},
|
||||||
|
}])
|
||||||
|
|
||||||
|
chat_template_kwargs = {
|
||||||
|
# both unused
|
||||||
|
"unsed_kwargs_1": 123,
|
||||||
|
"unsed_kwargs_2": "abc",
|
||||||
|
# should not appear
|
||||||
|
"chat_template": "{% Hello world! %}",
|
||||||
|
# used by tokenizer
|
||||||
|
"continue_final_message": True,
|
||||||
|
"tools": tools,
|
||||||
|
# both used by Qwen2-VL and Qwen3
|
||||||
|
"add_generation_prompt": True,
|
||||||
|
# only used by Qwen2-VL
|
||||||
|
"add_vision_id": True,
|
||||||
|
# only used by Qwen3
|
||||||
|
"enable_thinking": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model,
|
||||||
|
tokenizer=model_info.tokenizer or model,
|
||||||
|
tokenizer_mode=model_info.tokenizer_mode,
|
||||||
|
revision=model_info.revision,
|
||||||
|
trust_remote_code=model_info.trust_remote_code,
|
||||||
|
hf_overrides=model_info.hf_overrides,
|
||||||
|
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||||
|
enforce_eager=model_info.enforce_eager,
|
||||||
|
dtype=model_info.dtype)
|
||||||
|
|
||||||
|
# Build the tokenizer
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
model,
|
||||||
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test detecting the tokenizer's chat_template
|
||||||
|
chat_template = resolve_hf_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=None,
|
||||||
|
tools=tools,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=chat_template,
|
||||||
|
chat_template_kwargs=chat_template_kwargs,
|
||||||
|
)
|
||||||
|
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
|
||||||
|
|
||||||
|
|
||||||
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
||||||
# processor class instead of using `tokenizer_config.json`
|
# processor class instead of using `tokenizer_config.json`
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
|
|||||||
@ -1,52 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_regex():
|
|
||||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
|
||||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_json_schema():
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"age": {
|
|
||||||
"type": "integer"
|
|
||||||
},
|
|
||||||
"skills": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "string",
|
|
||||||
"maxLength": 10
|
|
||||||
},
|
|
||||||
"minItems": 3
|
|
||||||
},
|
|
||||||
"work_history": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"company": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"duration": {
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
"position": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["company", "position"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["name", "age", "skills", "work_history"]
|
|
||||||
}
|
|
||||||
@ -14,6 +14,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.model_loader.tensorizer
|
import vllm.model_executor.model_loader.tensorizer
|
||||||
|
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -27,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer_loader import (
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.utils import PlaceholderModule
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
|
||||||
from .conftest import DummyExecutor, assert_from_collective_rpc
|
from .conftest import DummyExecutor, assert_from_collective_rpc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -651,6 +651,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
||||||
|
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
|
||||||
|
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||||
|
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
|
||||||
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||||
min_transformers_version="4.56.3"),
|
min_transformers_version="4.56.3"),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -100,10 +100,9 @@ def test_distributed(
|
|||||||
kwargs_test=kwargs)
|
kwargs_test=kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
current_platform.is_rocm(),
|
|
||||||
reason="bitsandbytes quantization is currently not supported in rocm.")
|
|
||||||
@pytest.mark.parametrize("model, quantization_kwargs", [
|
@pytest.mark.parametrize("model, quantization_kwargs", [
|
||||||
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
|
||||||
|
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
|
||||||
(
|
(
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
{
|
{
|
||||||
@ -121,6 +120,11 @@ def test_quantization(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if (current_platform.is_rocm()
|
||||||
|
and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
|
||||||
|
pytest.skip(
|
||||||
|
"bitsandbytes quantization is currently not supported in rocm.")
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model, model_impl="auto", enforce_eager=True,
|
model, model_impl="auto", enforce_eager=True,
|
||||||
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||||
|
|||||||
@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
|
|||||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||||
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
|
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
|
||||||
raw_image_url: str, suffix: str):
|
raw_image_url: str, suffix: str):
|
||||||
connector = MediaConnector()
|
connector = MediaConnector(
|
||||||
|
# Domain restriction should not apply to data URLs.
|
||||||
|
allowed_media_domains=[
|
||||||
|
"www.bogotobogo.com",
|
||||||
|
"github.com",
|
||||||
|
])
|
||||||
url_image = url_images[raw_image_url]
|
url_image = url_images[raw_image_url]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
|
|||||||
modality_idxs = argsort_mm_positions(mm_positions)
|
modality_idxs = argsort_mm_positions(mm_positions)
|
||||||
|
|
||||||
assert modality_idxs == expected_modality_idxs
|
assert modality_idxs == expected_modality_idxs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
|
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||||
|
async def test_allowed_media_domains(video_url: str, num_frames: int):
|
||||||
|
connector = MediaConnector(
|
||||||
|
media_io_kwargs={"video": {
|
||||||
|
"num_frames": num_frames,
|
||||||
|
}},
|
||||||
|
allowed_media_domains=[
|
||||||
|
"www.bogotobogo.com",
|
||||||
|
"github.com",
|
||||||
|
])
|
||||||
|
|
||||||
|
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||||
|
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||||
|
assert np.array_equal(video_sync, video_async)
|
||||||
|
assert metadata_sync == metadata_async
|
||||||
|
|
||||||
|
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_, _ = connector.fetch_video(disallowed_url)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_, _ = await connector.fetch_video_async(disallowed_url)
|
||||||
|
|||||||
0
tests/v1/distributed/__init__.py
Normal file
0
tests/v1/distributed/__init__.py
Normal file
@ -12,7 +12,7 @@ import pytest_asyncio
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.test_utils import check_request_balancing
|
from tests.v1.utils import check_request_balancing
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||||
@ -13,7 +13,7 @@ import pytest_asyncio
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.test_utils import check_request_balancing
|
from tests.v1.utils import check_request_balancing
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||||
@ -8,7 +8,7 @@ from typing import Any, Union
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.utils import get_attn_backend_list_based_on_platform
|
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||||
from vllm.assets.image import VLM_IMAGES_DIR
|
from vllm.assets.image import VLM_IMAGES_DIR
|
||||||
@ -88,69 +88,71 @@ def test_ngram_correctness(
|
|||||||
Compare the outputs of an original LLM and a speculative LLM
|
Compare the outputs of an original LLM and a speculative LLM
|
||||||
should be the same when using ngram speculative decoding.
|
should be the same when using ngram speculative decoding.
|
||||||
'''
|
'''
|
||||||
with monkeypatch.context() as m:
|
test_prompts = get_test_prompts(mm_enabled=False)
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
|
||||||
test_prompts = get_test_prompts(mm_enabled=False)
|
|
||||||
|
|
||||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
del ref_llm
|
del ref_llm
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
spec_llm = LLM(
|
spec_llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "ngram",
|
"method": "ngram",
|
||||||
"prompt_lookup_max": 5,
|
"prompt_lookup_max": 5,
|
||||||
"prompt_lookup_min": 3,
|
"prompt_lookup_min": 3,
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
},
|
},
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
)
|
)
|
||||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||||
matches = 0
|
matches = 0
|
||||||
misses = 0
|
misses = 0
|
||||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||||
matches += 1
|
matches += 1
|
||||||
else:
|
else:
|
||||||
misses += 1
|
misses += 1
|
||||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||||
|
|
||||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches >= int(0.66 * len(ref_outputs))
|
assert matches >= int(0.66 * len(ref_outputs))
|
||||||
del spec_llm
|
del spec_llm
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
@pytest.mark.parametrize(
|
||||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
["model_setup", "mm_enabled"],
|
||||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
[
|
||||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
|
||||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
|
||||||
pytest.param(
|
False,
|
||||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
marks=pytest.mark.skip(reason="Skipping due to its " \
|
||||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
"head_dim not being a a multiple of 32")),
|
||||||
False,
|
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||||
pytest.param(
|
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
True,
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
False,
|
||||||
(("eagle", "eagle618/deepseek-v3-random",
|
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
],
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
ids=[
|
True,
|
||||||
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
|
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||||
"llama4_eagle", "llama4_eagle_mm",
|
(("eagle", "eagle618/deepseek-v3-random",
|
||||||
"deepseek_eagle"
|
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||||
])
|
],
|
||||||
|
ids=[
|
||||||
|
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||||
|
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
|
||||||
|
])
|
||||||
@pytest.mark.parametrize("attn_backend",
|
@pytest.mark.parametrize("attn_backend",
|
||||||
get_attn_backend_list_based_on_platform())
|
get_attn_backend_list_based_on_platform())
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
@ -174,9 +176,14 @@ def test_eagle_correctness(
|
|||||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||||
'''
|
'''
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
# Scout requires default backend selection
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
# because vision encoder has head_dim 88 being incompatible
|
||||||
|
# with FLASH_ATTN and needs to fall back to Flex Attn
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||||
|
|
||||||
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
|
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
|
||||||
pytest.skip("TRITON_ATTN does not support "
|
pytest.skip("TRITON_ATTN does not support "
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import pytest
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.test_utils import check_request_balancing
|
from tests.v1.utils import check_request_balancing
|
||||||
|
|
||||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||||
|
|
||||||
|
|||||||
290
tests/v1/generation/test_batch_invariance.py
Normal file
290
tests/v1/generation/test_batch_invariance.py
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||||
|
# Lightweight random prompt generator to vary prompt lengths and content.
|
||||||
|
vocab = [
|
||||||
|
"alpha",
|
||||||
|
"bravo",
|
||||||
|
"charlie",
|
||||||
|
"delta",
|
||||||
|
"echo",
|
||||||
|
"foxtrot",
|
||||||
|
"golf",
|
||||||
|
"hotel",
|
||||||
|
"india",
|
||||||
|
"juliet",
|
||||||
|
"kilo",
|
||||||
|
"lima",
|
||||||
|
"mike",
|
||||||
|
"november",
|
||||||
|
"oscar",
|
||||||
|
"papa",
|
||||||
|
"quebec",
|
||||||
|
"romeo",
|
||||||
|
"sierra",
|
||||||
|
"tango",
|
||||||
|
"uniform",
|
||||||
|
"victor",
|
||||||
|
"whiskey",
|
||||||
|
"xray",
|
||||||
|
"yankee",
|
||||||
|
"zulu",
|
||||||
|
]
|
||||||
|
n = random.randint(min_words, max_words)
|
||||||
|
words = random.choices(vocab, k=n)
|
||||||
|
|
||||||
|
# Add some noise and punctuation variability
|
||||||
|
if random.random() < 0.5:
|
||||||
|
words[0] = words[0].capitalize()
|
||||||
|
if random.random() < 0.2:
|
||||||
|
words.append("".join(random.choices(string.ascii_lowercase, k=5)))
|
||||||
|
punct = random.choice([".", "?", "!", "...", ""])
|
||||||
|
return " ".join(words) + punct
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(1000)
|
||||||
|
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||||
|
"""
|
||||||
|
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||||
|
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||||
|
using the high-level v1 LLM() API only (no manual batching).
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Create two LLM engines with identical config except max_num_seqs: 1 vs N.
|
||||||
|
- Compute a baseline output for the needle prompt with the bs=1 engine.
|
||||||
|
- For many trials, generate a batch (size N) where the needle appears at a
|
||||||
|
random position among random filler prompts using the bs=N engine.
|
||||||
|
- Track how many trials match vs mismatch, and report totals at the end.
|
||||||
|
The test fails if any mismatches occur, but we still dump pass/fail
|
||||||
|
counts.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Use seeded stochastic sampling with a fixed seed to test determinism.
|
||||||
|
- Outputs are intentionally longer and sampled at higher temperature/top_p
|
||||||
|
to produce a more random-sounding phrase, yet remain deterministic by
|
||||||
|
seed.
|
||||||
|
- Keep max_tokens and max_model_len bounded for speed and memory use.
|
||||||
|
"""
|
||||||
|
random.seed(12345)
|
||||||
|
|
||||||
|
# Allow overrides from environment (useful for CI tuning)
|
||||||
|
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||||
|
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
|
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
|
||||||
|
batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
|
||||||
|
assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
|
||||||
|
|
||||||
|
# Keep GPU memory usage low to avoid startup allocation failures.
|
||||||
|
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
|
||||||
|
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
|
||||||
|
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
|
||||||
|
|
||||||
|
# Sampling parameters: longer outputs with a more random-sounding
|
||||||
|
# continuation,but still deterministic due to fixed seed.
|
||||||
|
temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
|
||||||
|
top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
|
||||||
|
max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))
|
||||||
|
|
||||||
|
sampling = SamplingParams(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
seed=20240919,
|
||||||
|
)
|
||||||
|
|
||||||
|
needle_prompt = ("There once was a ")
|
||||||
|
|
||||||
|
llm_bs1 = None
|
||||||
|
llm_bsN = None
|
||||||
|
try:
|
||||||
|
# Engine with bs=1 behavior
|
||||||
|
llm_bs1 = LLM_with_max_seqs(
|
||||||
|
model=model,
|
||||||
|
max_num_seqs=1,
|
||||||
|
gpu_memory_utilization=gpu_mem_util,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
swap_space=swap_space_gb,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Baseline generation for the needle prompt alone.
|
||||||
|
baseline_out = llm_bs1.generate([needle_prompt], sampling)
|
||||||
|
assert len(baseline_out) == 1
|
||||||
|
assert len(baseline_out[0].outputs) >= 1
|
||||||
|
baseline_text = baseline_out[0].outputs[0].text
|
||||||
|
|
||||||
|
# Engine with larger batch limit (e.g., 64)
|
||||||
|
llm_bsN = LLM_with_max_seqs(
|
||||||
|
model=model,
|
||||||
|
max_num_seqs=batch_size,
|
||||||
|
gpu_memory_utilization=gpu_mem_util,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
swap_space=swap_space_gb,
|
||||||
|
)
|
||||||
|
|
||||||
|
mismatches = 0
|
||||||
|
|
||||||
|
for trial in range(num_trials):
|
||||||
|
# Create a batch of size `batch_size` and insert the needle at
|
||||||
|
# a random index
|
||||||
|
prompts: list[str] = []
|
||||||
|
needle_pos = random.randint(0, batch_size - 1)
|
||||||
|
for i in range(batch_size):
|
||||||
|
if i == needle_pos:
|
||||||
|
prompts.append(needle_prompt)
|
||||||
|
else:
|
||||||
|
prompts.append(_random_prompt())
|
||||||
|
|
||||||
|
# Generate with the larger-batch engine
|
||||||
|
outputs = llm_bsN.generate(prompts, sampling)
|
||||||
|
# Find the needle output by position
|
||||||
|
needle_output = outputs[needle_pos]
|
||||||
|
assert needle_output.prompt == needle_prompt
|
||||||
|
assert len(needle_output.outputs) >= 1
|
||||||
|
text = needle_output.outputs[0].text
|
||||||
|
|
||||||
|
if text != baseline_text:
|
||||||
|
mismatches += 1
|
||||||
|
|
||||||
|
passes = num_trials - mismatches
|
||||||
|
# Dump how many passed vs failed
|
||||||
|
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||||
|
f"failed={mismatches}, batch_size={batch_size}")
|
||||||
|
|
||||||
|
if mismatches > 0:
|
||||||
|
pytest.fail(
|
||||||
|
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||||
|
f"of {num_trials} trials (batch_size={batch_size}).")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Ensure engines are shutdown to free GPU/VRAM across test sessions
|
||||||
|
if llm_bs1 is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
llm_bs1.shutdown()
|
||||||
|
if llm_bsN is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
llm_bsN.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_step_logprobs(request_output):
|
||||||
|
if getattr(request_output, "outputs", None):
|
||||||
|
inner = request_output.outputs[0]
|
||||||
|
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||||
|
t = torch.tensor(
|
||||||
|
[
|
||||||
|
inner.logprobs[i][tid].logprob
|
||||||
|
for i, tid in enumerate(inner.token_ids)
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
return t
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not torch.cuda.is_available(),
|
||||||
|
reason="Requires CUDA to match production inference path.",
|
||||||
|
)
|
||||||
|
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||||
|
|
||||||
|
#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
|
||||||
|
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
|
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||||
|
|
||||||
|
# Force float32 to avoid precision-induced differences.
|
||||||
|
llm = LLM(
|
||||||
|
model=model_name,
|
||||||
|
tensor_parallel_size=tp_size,
|
||||||
|
enforce_eager=True, # helps reduce nondeterminism from some backends
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"The capital of France is",
|
||||||
|
"The capital of Germany is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sp = SamplingParams(
|
||||||
|
temperature=0.0,
|
||||||
|
top_p=1.0,
|
||||||
|
max_tokens=8,
|
||||||
|
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
|
||||||
|
seed=1234,
|
||||||
|
logprobs=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# BS=1: run prompts individually and collect logprobs per step.
|
||||||
|
bs1_logprobs_per_prompt = []
|
||||||
|
for p in prompts:
|
||||||
|
outs = llm.generate([p], sp, use_tqdm=False)
|
||||||
|
assert len(outs) == 1
|
||||||
|
step_logprobs = _extract_step_logprobs(outs[0])
|
||||||
|
if step_logprobs is None:
|
||||||
|
pytest.skip("Logits are not available on RequestOutput; "
|
||||||
|
"enable logprobs return to run this test.")
|
||||||
|
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||||
|
|
||||||
|
# BS=2: run prompts in a batch and collect logprobs per step for each
|
||||||
|
# prompt.
|
||||||
|
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
|
||||||
|
assert len(outs_batched) == len(prompts)
|
||||||
|
bs2_logprobs_per_prompt = []
|
||||||
|
for o in outs_batched:
|
||||||
|
step_logprobs = _extract_step_logprobs(o)
|
||||||
|
if step_logprobs is None:
|
||||||
|
pytest.skip("Logits are not available on RequestOutput; "
|
||||||
|
"enable logprobs return to run this test.")
|
||||||
|
bs2_logprobs_per_prompt.append(step_logprobs)
|
||||||
|
|
||||||
|
# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
|
||||||
|
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
|
||||||
|
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
|
||||||
|
assert len(logprobs_bs1) == len(logprobs_bs2), (
|
||||||
|
f"Different number of generation steps for prompt index {i}: "
|
||||||
|
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
|
||||||
|
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
|
||||||
|
assert a.shape == b.shape, (
|
||||||
|
f"Logits shape mismatch at prompt {i}, step {t}: "
|
||||||
|
f"{a.shape} vs {b.shape}")
|
||||||
|
# Bitwise exact equality.
|
||||||
|
assert torch.equal(
|
||||||
|
a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||||
|
f"(dtype={a.dtype}, shape={a.shape}).")
|
||||||
|
|
||||||
|
|
||||||
|
def LLM_with_max_seqs(
|
||||||
|
model: str,
|
||||||
|
max_num_seqs: int,
|
||||||
|
gpu_memory_utilization: float,
|
||||||
|
max_model_len: int,
|
||||||
|
swap_space: int,
|
||||||
|
) -> LLM:
|
||||||
|
"""
|
||||||
|
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
|
||||||
|
using the high-level v1 LLM API, while constraining memory usage.
|
||||||
|
"""
|
||||||
|
return LLM(
|
||||||
|
model=model,
|
||||||
|
max_num_seqs=max_num_seqs,
|
||||||
|
# Constrain GPU memory pool so test can run even on busy GPUs.
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
# Keep KV cache footprint small while allowing longer outputs.
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
# Allow some CPU offload if needed.
|
||||||
|
swap_space=swap_space,
|
||||||
|
# Keep things lean and CI-friendly.
|
||||||
|
dtype="float16",
|
||||||
|
# Single-GPU by default; override externally if desired.
|
||||||
|
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||||
|
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
|
||||||
|
)
|
||||||
@ -1,71 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import regex as re
|
import regex as re
|
||||||
import requests
|
import requests
|
||||||
import torch
|
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from vllm.v1.worker.utils import bind_kv_cache
|
|
||||||
|
|
||||||
|
|
||||||
def test_bind_kv_cache():
|
|
||||||
from vllm.attention import Attention
|
|
||||||
|
|
||||||
ctx = {
|
|
||||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
|
||||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
|
||||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
|
||||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
|
||||||
}
|
|
||||||
kv_cache = {
|
|
||||||
'layers.0.self_attn': torch.zeros((1, )),
|
|
||||||
'layers.1.self_attn': torch.zeros((1, )),
|
|
||||||
'layers.2.self_attn': torch.zeros((1, )),
|
|
||||||
'layers.3.self_attn': torch.zeros((1, )),
|
|
||||||
}
|
|
||||||
runner_kv_caches: list[torch.Tensor] = []
|
|
||||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
|
||||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
|
||||||
'layers.0.self_attn']
|
|
||||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
|
||||||
'layers.1.self_attn']
|
|
||||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
|
||||||
'layers.2.self_attn']
|
|
||||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
|
||||||
'layers.3.self_attn']
|
|
||||||
|
|
||||||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
|
||||||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
|
||||||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
|
||||||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
|
||||||
|
|
||||||
|
|
||||||
def test_bind_kv_cache_non_attention():
|
|
||||||
from vllm.attention import Attention
|
|
||||||
|
|
||||||
# example from Jamba PP=2
|
|
||||||
ctx = {
|
|
||||||
'model.layers.20.attn': Attention(32, 128, 0.1),
|
|
||||||
'model.layers.28.attn': Attention(32, 128, 0.1),
|
|
||||||
}
|
|
||||||
kv_cache = {
|
|
||||||
'model.layers.20.attn': torch.zeros((1, )),
|
|
||||||
'model.layers.28.attn': torch.zeros((1, )),
|
|
||||||
}
|
|
||||||
|
|
||||||
runner_kv_caches: list[torch.Tensor] = []
|
|
||||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
|
||||||
|
|
||||||
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
|
|
||||||
'model.layers.20.attn']
|
|
||||||
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
|
|
||||||
'model.layers.28.attn']
|
|
||||||
|
|
||||||
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
|
|
||||||
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
|
|
||||||
|
|
||||||
|
|
||||||
# Prometheus metrics utilities for testing
|
# Prometheus metrics utilities for testing
|
||||||
|
|
||||||
63
tests/v1/worker/test_utils.py
Normal file
63
tests/v1/worker/test_utils.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.v1.worker.utils import bind_kv_cache
|
||||||
|
|
||||||
|
|
||||||
|
def test_bind_kv_cache():
|
||||||
|
from vllm.attention import Attention
|
||||||
|
|
||||||
|
ctx = {
|
||||||
|
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||||
|
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||||
|
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||||
|
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||||
|
}
|
||||||
|
kv_cache = {
|
||||||
|
'layers.0.self_attn': torch.zeros((1, )),
|
||||||
|
'layers.1.self_attn': torch.zeros((1, )),
|
||||||
|
'layers.2.self_attn': torch.zeros((1, )),
|
||||||
|
'layers.3.self_attn': torch.zeros((1, )),
|
||||||
|
}
|
||||||
|
runner_kv_caches: list[torch.Tensor] = []
|
||||||
|
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||||
|
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
||||||
|
'layers.0.self_attn']
|
||||||
|
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
||||||
|
'layers.1.self_attn']
|
||||||
|
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
||||||
|
'layers.2.self_attn']
|
||||||
|
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
||||||
|
'layers.3.self_attn']
|
||||||
|
|
||||||
|
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
||||||
|
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
||||||
|
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
||||||
|
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
||||||
|
|
||||||
|
|
||||||
|
def test_bind_kv_cache_non_attention():
|
||||||
|
from vllm.attention import Attention
|
||||||
|
|
||||||
|
# example from Jamba PP=2
|
||||||
|
ctx = {
|
||||||
|
'model.layers.20.attn': Attention(32, 128, 0.1),
|
||||||
|
'model.layers.28.attn': Attention(32, 128, 0.1),
|
||||||
|
}
|
||||||
|
kv_cache = {
|
||||||
|
'model.layers.20.attn': torch.zeros((1, )),
|
||||||
|
'model.layers.28.attn': torch.zeros((1, )),
|
||||||
|
}
|
||||||
|
|
||||||
|
runner_kv_caches: list[torch.Tensor] = []
|
||||||
|
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||||
|
|
||||||
|
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
|
||||||
|
'model.layers.20.attn']
|
||||||
|
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
|
||||||
|
'model.layers.28.attn']
|
||||||
|
|
||||||
|
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
|
||||||
|
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
|
||||||
63
tools/flashinfer-build.sh
Normal file
63
tools/flashinfer-build.sh
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# This script is used to build FlashInfer wheels with AOT kernels
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# FlashInfer configuration
|
||||||
|
FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||||
|
FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}"
|
||||||
|
CUDA_VERSION="${CUDA_VERSION}"
|
||||||
|
BUILD_WHEEL="${BUILD_WHEEL:-true}"
|
||||||
|
|
||||||
|
if [[ -z "${FLASHINFER_GIT_REF}" ]]; then
|
||||||
|
echo "❌ FLASHINFER_GIT_REF must be specified" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z "${CUDA_VERSION}" ]]; then
|
||||||
|
echo "❌ CUDA_VERSION must be specified" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "🏗️ Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}"
|
||||||
|
|
||||||
|
# Clone FlashInfer
|
||||||
|
git clone --depth 1 --recursive --shallow-submodules \
|
||||||
|
--branch ${FLASHINFER_GIT_REF} \
|
||||||
|
${FLASHINFER_GIT_REPO} flashinfer
|
||||||
|
|
||||||
|
# Set CUDA arch list based on CUDA version
|
||||||
|
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||||
|
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||||
|
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||||
|
else
|
||||||
|
# CUDA 12.8+ supports 10.0a and 12.0
|
||||||
|
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "🏗️ Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||||
|
|
||||||
|
pushd flashinfer
|
||||||
|
# Make sure the wheel is built for the correct CUDA version
|
||||||
|
export UV_TORCH_BACKEND=cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||||
|
|
||||||
|
# Build AOT kernels
|
||||||
|
export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||||
|
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||||
|
python3 -m flashinfer.aot
|
||||||
|
|
||||||
|
if [[ "${BUILD_WHEEL}" == "true" ]]; then
|
||||||
|
# Build wheel for distribution
|
||||||
|
uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist .
|
||||||
|
echo "✅ FlashInfer wheel built successfully in flashinfer-dist/"
|
||||||
|
else
|
||||||
|
# Install directly (for Dockerfile)
|
||||||
|
uv pip install --system --no-build-isolation --force-reinstall .
|
||||||
|
echo "✅ FlashInfer installed successfully"
|
||||||
|
fi
|
||||||
|
popd
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
rm -rf flashinfer
|
||||||
@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
):
|
):
|
||||||
dataset_class = MLPerfDataset
|
dataset_class = MLPerfDataset
|
||||||
args.hf_split = "train"
|
args.hf_split = "train"
|
||||||
|
elif (
|
||||||
|
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||||
|
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||||
|
):
|
||||||
|
dataset_class = MMStarDataset
|
||||||
|
args.hf_split = "val"
|
||||||
|
args.hf_subset = None
|
||||||
else:
|
else:
|
||||||
supported_datasets = set([
|
supported_datasets = set([
|
||||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||||
@ -2721,3 +2728,76 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
|||||||
|
|
||||||
random.shuffle(requests)
|
random.shuffle(requests)
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# MMStar Dataset Implementation
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class MMStarDataset(HuggingFaceDataset):
|
||||||
|
"""
|
||||||
|
Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
|
||||||
|
refer to: https://github.com/sgl-project/SpecForge/pull/106
|
||||||
|
"""
|
||||||
|
DEFAULT_OUTPUT_LEN = 128
|
||||||
|
SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
|
||||||
|
IS_MULTIMODAL = True
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
num_requests: int,
|
||||||
|
output_len: Optional[int] = None,
|
||||||
|
enable_multimodal_chat: bool = False,
|
||||||
|
request_id_prefix: str = "",
|
||||||
|
no_oversample: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[SampleRequest]:
|
||||||
|
# If --hf-output-len is not set, use the default output length.
|
||||||
|
output_len = (output_len
|
||||||
|
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||||
|
sampled_requests: list[SampleRequest] = []
|
||||||
|
|
||||||
|
for ind, item in enumerate(self.data):
|
||||||
|
if len(sampled_requests) >= num_requests:
|
||||||
|
break
|
||||||
|
# Split the question text from options
|
||||||
|
# (keep only the part before "Options:").
|
||||||
|
full_q: str = item.get("question", "")
|
||||||
|
question_text = full_q.split("Options:", 1)[0].strip()
|
||||||
|
|
||||||
|
# Multimodal image content.
|
||||||
|
mm_content = process_image(item["image"])
|
||||||
|
|
||||||
|
# Compute prompt token length (note: this is plain text length
|
||||||
|
# if enable_multimodal_chat is False).
|
||||||
|
prompt_len = len(tokenizer(question_text).input_ids)
|
||||||
|
|
||||||
|
if enable_multimodal_chat:
|
||||||
|
# If multimodal content should be embedded in the chat message,
|
||||||
|
# convert to [{"role":"user","content":[...]}]
|
||||||
|
prompt = self.apply_multimodal_chat_transformation(
|
||||||
|
question_text, mm_content
|
||||||
|
)
|
||||||
|
mm_for_request = None # Already embedded in chat content.
|
||||||
|
else:
|
||||||
|
# Default: prompt is plain text,
|
||||||
|
# image is in mm_content for the bench to assemble.
|
||||||
|
prompt = question_text
|
||||||
|
mm_for_request = mm_content
|
||||||
|
|
||||||
|
sampled_requests.append(
|
||||||
|
SampleRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_len=prompt_len,
|
||||||
|
expected_output_len=output_len,
|
||||||
|
multi_modal_data=mm_for_request,
|
||||||
|
request_id=request_id_prefix + str(ind),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.maybe_oversample_requests(
|
||||||
|
sampled_requests, num_requests, request_id_prefix, no_oversample
|
||||||
|
)
|
||||||
|
return sampled_requests
|
||||||
|
|||||||
@ -137,6 +137,9 @@ class ModelConfig:
|
|||||||
"""Allowing API requests to read local images or videos from directories
|
"""Allowing API requests to read local images or videos from directories
|
||||||
specified by the server file system. This is a security risk. Should only
|
specified by the server file system. This is a security risk. Should only
|
||||||
be enabled in trusted environments."""
|
be enabled in trusted environments."""
|
||||||
|
allowed_media_domains: Optional[list[str]] = None
|
||||||
|
"""If set, only media URLs that belong to this domain can be used for
|
||||||
|
multi-modal inputs. """
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
"""The specific model version to use. It can be a branch name, a tag name,
|
"""The specific model version to use. It can be a branch name, a tag name,
|
||||||
or a commit id. If unspecified, will use the default version."""
|
or a commit id. If unspecified, will use the default version."""
|
||||||
@ -506,9 +509,14 @@ class ModelConfig:
|
|||||||
else: # task == "auto"
|
else: # task == "auto"
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
debug_info = {
|
||||||
|
"architectures": architectures,
|
||||||
|
"is_generative_model": is_generative_model,
|
||||||
|
"is_pooling_model": is_pooling_model,
|
||||||
|
}
|
||||||
raise AssertionError("The model should be a generative or "
|
raise AssertionError("The model should be a generative or "
|
||||||
"pooling model when task is set to "
|
"pooling model when task is set to "
|
||||||
f"{self.task!r}.")
|
f"{self.task!r}. Found: {debug_info}")
|
||||||
|
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
self.convert = convert
|
self.convert = convert
|
||||||
|
|||||||
@ -279,6 +279,24 @@ class ParallelConfig:
|
|||||||
assert last_exc is not None
|
assert last_exc is not None
|
||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
|
# The all_reduce at the end of attention (during o_proj) means that
|
||||||
|
# inputs are replicated across each rank of the tensor parallel group.
|
||||||
|
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||||
|
# tokens results in useless duplicate computation and communication.
|
||||||
|
#
|
||||||
|
# In this case, ensure the input to the experts is sequence parallel
|
||||||
|
# to avoid the excess work.
|
||||||
|
#
|
||||||
|
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||||
|
@property
|
||||||
|
def use_sequence_parallel_moe(self) -> bool:
|
||||||
|
return (envs.VLLM_ALL2ALL_BACKEND
|
||||||
|
in ("allgather_reducescatter", "naive",
|
||||||
|
"deepep_high_throughput", "deepep_low_latency")
|
||||||
|
and self.enable_expert_parallel
|
||||||
|
and self.tensor_parallel_size > 1
|
||||||
|
and self.data_parallel_size > 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_unfinished_dp(dp_group: ProcessGroup,
|
def has_unfinished_dp(dp_group: ProcessGroup,
|
||||||
has_unfinished: bool) -> bool:
|
has_unfinished: bool) -> bool:
|
||||||
|
|||||||
@ -288,6 +288,8 @@ class SpeculativeConfig:
|
|||||||
trust_remote_code,
|
trust_remote_code,
|
||||||
allowed_local_media_path=self.target_model_config.
|
allowed_local_media_path=self.target_model_config.
|
||||||
allowed_local_media_path,
|
allowed_local_media_path,
|
||||||
|
allowed_media_domains=self.target_model_config.
|
||||||
|
allowed_media_domains,
|
||||||
dtype=self.target_model_config.dtype,
|
dtype=self.target_model_config.dtype,
|
||||||
seed=self.target_model_config.seed,
|
seed=self.target_model_config.seed,
|
||||||
revision=self.revision,
|
revision=self.revision,
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed import get_dp_group
|
from vllm.distributed import get_dp_group, get_ep_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import has_deep_ep, has_pplx
|
from vllm.utils import has_deep_ep, has_pplx
|
||||||
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
|||||||
super().__init__(cpu_group)
|
super().__init__(cpu_group)
|
||||||
|
|
||||||
def naive_multicast(self, x: torch.Tensor,
|
def naive_multicast(self, x: torch.Tensor,
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool) -> torch.Tensor:
|
||||||
assert (len(x.shape) == 2)
|
assert (len(x.shape) == 2)
|
||||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=x.dtype)
|
dtype=x.dtype)
|
||||||
|
|
||||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||||
self.dp_rank - 1]
|
world_size = (self.world_size
|
||||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
if is_sequence_parallel else self.dp_world_size)
|
||||||
|
|
||||||
|
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||||
|
end = cu_tokens_across_sp_cpu[rank]
|
||||||
buffer[start:end, :].copy_(x)
|
buffer[start:end, :].copy_(x)
|
||||||
for idx in range(self.dp_world_size):
|
for idx in range(world_size):
|
||||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||||
end = cu_tokens_across_dp_cpu[idx]
|
end = cu_tokens_across_sp_cpu[idx]
|
||||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||||
|
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(
|
||||||
router_logits: torch.Tensor):
|
self,
|
||||||
sizes = get_forward_context(
|
hidden_states: torch.Tensor,
|
||||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
router_logits: torch.Tensor,
|
||||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
is_sequence_parallel: bool = False
|
||||||
[hidden_states, router_logits],
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
dim=0,
|
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||||
sizes=sizes,
|
dp_metadata = get_forward_context().dp_metadata
|
||||||
)
|
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||||
|
|
||||||
|
hidden_states = self.naive_multicast(hidden_states,
|
||||||
|
cu_tokens_across_sp_cpu,
|
||||||
|
is_sequence_parallel)
|
||||||
|
router_logits = self.naive_multicast(router_logits,
|
||||||
|
cu_tokens_across_sp_cpu,
|
||||||
|
is_sequence_parallel)
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
sizes = get_forward_context(
|
hidden_states: torch.Tensor,
|
||||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
|
||||||
dim=0,
|
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||||
sizes=sizes)
|
|
||||||
|
dp_metadata = get_forward_context().dp_metadata
|
||||||
|
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||||
|
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||||
|
|
||||||
|
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||||
|
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||||
|
|
||||||
|
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||||
|
hidden_states = all_hidden_states[start:end, :]
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
|||||||
def __init__(self, cpu_group):
|
def __init__(self, cpu_group):
|
||||||
super().__init__(cpu_group)
|
super().__init__(cpu_group)
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(
|
||||||
router_logits: torch.Tensor):
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Gather hidden_states and router_logits from all dp ranks.
|
Gather hidden_states and router_logits from all dp ranks.
|
||||||
"""
|
"""
|
||||||
sizes = get_forward_context(
|
sizes = get_forward_context(
|
||||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
|
||||||
|
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||||
|
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||||
|
hidden_states, router_logits = dist_group.all_gatherv(
|
||||||
[hidden_states, router_logits],
|
[hidden_states, router_logits],
|
||||||
dim=0,
|
dim=0,
|
||||||
sizes=sizes,
|
sizes=sizes,
|
||||||
)
|
)
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Reduce-scatter hidden_states across all dp ranks.
|
Reduce-scatter hidden_states across all dp ranks.
|
||||||
"""
|
"""
|
||||||
sizes = get_forward_context(
|
sizes = get_forward_context(
|
||||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
|
||||||
dim=0,
|
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||||
sizes=sizes)
|
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||||
|
dim=0,
|
||||||
|
sizes=sizes)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
|||||||
kwargs, pplx.AllToAll.internode
|
kwargs, pplx.AllToAll.internode
|
||||||
if self.internode else pplx.AllToAll.intranode)
|
if self.internode else pplx.AllToAll.intranode)
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(
|
||||||
router_logits: torch.Tensor):
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
|||||||
def get_handle(self, kwargs):
|
def get_handle(self, kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def dispatch(
|
||||||
router_logits: torch.Tensor):
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
|||||||
self.workspace_tensor = None
|
self.workspace_tensor = None
|
||||||
self.prepare_workspace_tensor = None
|
self.prepare_workspace_tensor = None
|
||||||
self.mapping = None
|
self.mapping = None
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
|||||||
@ -28,6 +28,8 @@ class Cache:
|
|||||||
|
|
||||||
|
|
||||||
class All2AllManagerBase:
|
class All2AllManagerBase:
|
||||||
|
rank: int
|
||||||
|
world_size: int
|
||||||
|
|
||||||
def __init__(self, cpu_group):
|
def __init__(self, cpu_group):
|
||||||
self.cpu_group = cpu_group
|
self.cpu_group = cpu_group
|
||||||
@ -40,6 +42,7 @@ class All2AllManagerBase:
|
|||||||
# all2all lives in ep group, which is merged from dp and tp group
|
# all2all lives in ep group, which is merged from dp and tp group
|
||||||
self.dp_group = get_dp_group()
|
self.dp_group = get_dp_group()
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
# no self.ep_group since self.ep_group is still in construction
|
# no self.ep_group since self.ep_group is still in construction
|
||||||
# when we create this object
|
# when we create this object
|
||||||
self.dp_rank = self.dp_group.rank_in_group
|
self.dp_rank = self.dp_group.rank_in_group
|
||||||
@ -60,17 +63,21 @@ class All2AllManagerBase:
|
|||||||
# and reuse it for the same config.
|
# and reuse it for the same config.
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def dispatch(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def set_num_sms(self, num_sms: int):
|
def set_num_sms(self, num_sms: int):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def max_sms_used(self) -> Optional[int]:
|
def max_sms_used(self) -> Optional[int]:
|
||||||
return None # None means it could use the whole GPU
|
return None # None means it could use the whole GPU
|
||||||
|
|
||||||
def dispatch(self, hidden_states: torch.Tensor,
|
def combine(self,
|
||||||
router_logits: torch.Tensor):
|
hidden_states: torch.Tensor,
|
||||||
raise NotImplementedError
|
is_sequence_parallel: bool = False):
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
|
|||||||
module.quant_method.init_prepare_finalize(module)
|
module.quant_method.init_prepare_finalize(module)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self, hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Dispatch the hidden states and router logits to the appropriate device.
|
Dispatch the hidden states and router logits to the appropriate device.
|
||||||
This is a no-op in the base class.
|
This is a no-op in the base class.
|
||||||
"""
|
"""
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Combine the hidden states and router logits from the appropriate device.
|
Combine the hidden states and router logits from the appropriate device.
|
||||||
This is a no-op in the base class.
|
This is a no-op in the base class.
|
||||||
|
|||||||
@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||||
|
|
||||||
# ep does not use pynccl
|
|
||||||
use_pynccl = "ep" not in unique_name
|
|
||||||
|
|
||||||
self.use_pynccl = use_pynccl
|
|
||||||
self.use_custom_allreduce = use_custom_allreduce
|
self.use_custom_allreduce = use_custom_allreduce
|
||||||
self.use_torch_symm_mem = use_torch_symm_mem
|
self.use_torch_symm_mem = use_torch_symm_mem
|
||||||
|
|
||||||
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
SymmMemCommunicator)
|
SymmMemCommunicator)
|
||||||
|
|
||||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||||
if use_pynccl and self.world_size > 1:
|
if self.world_size > 1:
|
||||||
self.pynccl_comm = PyNcclCommunicator(
|
self.pynccl_comm = PyNcclCommunicator(
|
||||||
group=self.cpu_group,
|
group=self.cpu_group,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self, hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert self.all2all_manager is not None
|
assert self.all2all_manager is not None
|
||||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||||
hidden_states, router_logits)
|
hidden_states, router_logits, is_sequence_parallel)
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
assert self.all2all_manager is not None
|
assert self.all2all_manager is not None
|
||||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||||
|
is_sequence_parallel)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
|||||||
dist.broadcast(input_, src=src, group=self.device_group)
|
dist.broadcast(input_, src=src, group=self.device_group)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self, hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert self.all2all_manager is not None
|
assert self.all2all_manager is not None
|
||||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||||
hidden_states, router_logits)
|
hidden_states, router_logits, is_sequence_parallel)
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
assert self.all2all_manager is not None
|
assert self.all2all_manager is not None
|
||||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||||
|
is_sequence_parallel)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@ -871,17 +871,24 @@ class GroupCoordinator:
|
|||||||
model)
|
model)
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
self, hidden_states: torch.Tensor,
|
self,
|
||||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
is_sequence_parallel: bool = False
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.device_communicator is not None:
|
if self.device_communicator is not None:
|
||||||
return self.device_communicator.dispatch(hidden_states,
|
return self.device_communicator.dispatch(hidden_states,
|
||||||
router_logits)
|
router_logits,
|
||||||
|
is_sequence_parallel)
|
||||||
else:
|
else:
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(self, hidden_states) -> torch.Tensor:
|
def combine(self,
|
||||||
|
hidden_states,
|
||||||
|
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||||
if self.device_communicator is not None:
|
if self.device_communicator is not None:
|
||||||
return self.device_communicator.combine(hidden_states)
|
return self.device_communicator.combine(hidden_states,
|
||||||
|
is_sequence_parallel)
|
||||||
else:
|
else:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@ -297,6 +297,8 @@ class EngineArgs:
|
|||||||
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
||||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||||
|
allowed_media_domains: Optional[
|
||||||
|
list[str]] = ModelConfig.allowed_media_domains
|
||||||
download_dir: Optional[str] = LoadConfig.download_dir
|
download_dir: Optional[str] = LoadConfig.download_dir
|
||||||
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||||
@ -531,6 +533,8 @@ class EngineArgs:
|
|||||||
**model_kwargs["hf_config_path"])
|
**model_kwargs["hf_config_path"])
|
||||||
model_group.add_argument("--allowed-local-media-path",
|
model_group.add_argument("--allowed-local-media-path",
|
||||||
**model_kwargs["allowed_local_media_path"])
|
**model_kwargs["allowed_local_media_path"])
|
||||||
|
model_group.add_argument("--allowed-media-domains",
|
||||||
|
**model_kwargs["allowed_media_domains"])
|
||||||
model_group.add_argument("--revision", **model_kwargs["revision"])
|
model_group.add_argument("--revision", **model_kwargs["revision"])
|
||||||
model_group.add_argument("--code-revision",
|
model_group.add_argument("--code-revision",
|
||||||
**model_kwargs["code_revision"])
|
**model_kwargs["code_revision"])
|
||||||
@ -997,6 +1001,7 @@ class EngineArgs:
|
|||||||
tokenizer_mode=self.tokenizer_mode,
|
tokenizer_mode=self.tokenizer_mode,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
allowed_local_media_path=self.allowed_local_media_path,
|
allowed_local_media_path=self.allowed_local_media_path,
|
||||||
|
allowed_media_domains=self.allowed_media_domains,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
revision=self.revision,
|
revision=self.revision,
|
||||||
|
|||||||
@ -11,7 +11,12 @@ from pathlib import Path
|
|||||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||||
cast)
|
cast)
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
import jinja2.ext
|
||||||
|
import jinja2.meta
|
||||||
import jinja2.nodes
|
import jinja2.nodes
|
||||||
|
import jinja2.parser
|
||||||
|
import jinja2.sandbox
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid, supports_kw
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
def allowed_local_media_path(self):
|
def allowed_local_media_path(self):
|
||||||
return self._model_config.allowed_local_media_path
|
return self._model_config.allowed_local_media_path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed_media_domains(self):
|
||||||
|
return self._model_config.allowed_media_domains
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mm_registry(self):
|
def mm_registry(self):
|
||||||
return MULTIMODAL_REGISTRY
|
return MULTIMODAL_REGISTRY
|
||||||
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
self._connector = MediaConnector(
|
self._connector = MediaConnector(
|
||||||
media_io_kwargs=media_io_kwargs,
|
media_io_kwargs=media_io_kwargs,
|
||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
|
allowed_media_domains=tracker.allowed_media_domains,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_image(
|
def parse_image(
|
||||||
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
self._connector = MediaConnector(
|
self._connector = MediaConnector(
|
||||||
media_io_kwargs=media_io_kwargs,
|
media_io_kwargs=media_io_kwargs,
|
||||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
|
allowed_media_domains=tracker.allowed_media_domains,
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_image(
|
def parse_image(
|
||||||
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
|
|||||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||||
|
|
||||||
|
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||||
|
# only preserve the parse function used to resolve chat template kwargs
|
||||||
|
class AssistantTracker(jinja2.ext.Extension):
|
||||||
|
tags = {"generation"}
|
||||||
|
|
||||||
|
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||||
|
lineno = next(parser.stream).lineno
|
||||||
|
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||||
|
call = self.call_method("_generation_support")
|
||||||
|
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||||
|
return call_block.set_lineno(lineno)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_chat_template_kwargs(
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
|
chat_template: str,
|
||||||
|
chat_template_kwargs: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
fn_kw = {
|
||||||
|
k for k in chat_template_kwargs
|
||||||
|
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||||
|
}
|
||||||
|
|
||||||
|
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||||
|
trim_blocks=True,
|
||||||
|
lstrip_blocks=True,
|
||||||
|
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||||
|
)
|
||||||
|
parsed_content = env.parse(chat_template)
|
||||||
|
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||||
|
|
||||||
|
# We exclude chat_template from kwargs here, because
|
||||||
|
# chat template has been already resolved at this stage
|
||||||
|
unexpected_vars = {"chat_template"}
|
||||||
|
accept_vars = (fn_kw | template_vars) - unexpected_vars
|
||||||
|
return {
|
||||||
|
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def apply_hf_chat_template(
|
def apply_hf_chat_template(
|
||||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
conversation: list[ConversationMessage],
|
conversation: list[ConversationMessage],
|
||||||
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
resolved_kwargs = resolve_chat_template_kwargs(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chat_template=hf_chat_template,
|
||||||
|
chat_template_kwargs=kwargs,
|
||||||
|
)
|
||||||
return tokenizer.apply_chat_template(
|
return tokenizer.apply_chat_template(
|
||||||
conversation=conversation, # type: ignore[arg-type]
|
conversation=conversation, # type: ignore[arg-type]
|
||||||
tools=tools, # type: ignore[arg-type]
|
tools=tools, # type: ignore[arg-type]
|
||||||
chat_template=hf_chat_template,
|
chat_template=hf_chat_template,
|
||||||
tokenize=tokenize,
|
tokenize=tokenize,
|
||||||
**kwargs,
|
**resolved_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# External library exceptions can sometimes occur despite the framework's
|
# External library exceptions can sometimes occur despite the framework's
|
||||||
|
|||||||
@ -86,6 +86,8 @@ class LLM:
|
|||||||
or videos from directories specified by the server file system.
|
or videos from directories specified by the server file system.
|
||||||
This is a security risk. Should only be enabled in trusted
|
This is a security risk. Should only be enabled in trusted
|
||||||
environments.
|
environments.
|
||||||
|
allowed_media_domains: If set, only media URLs that belong to this
|
||||||
|
domain can be used for multi-modal inputs.
|
||||||
tensor_parallel_size: The number of GPUs to use for distributed
|
tensor_parallel_size: The number of GPUs to use for distributed
|
||||||
execution with tensor parallelism.
|
execution with tensor parallelism.
|
||||||
dtype: The data type for the model weights and activations. Currently,
|
dtype: The data type for the model weights and activations. Currently,
|
||||||
@ -169,6 +171,7 @@ class LLM:
|
|||||||
skip_tokenizer_init: bool = False,
|
skip_tokenizer_init: bool = False,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
allowed_local_media_path: str = "",
|
allowed_local_media_path: str = "",
|
||||||
|
allowed_media_domains: Optional[list[str]] = None,
|
||||||
tensor_parallel_size: int = 1,
|
tensor_parallel_size: int = 1,
|
||||||
dtype: ModelDType = "auto",
|
dtype: ModelDType = "auto",
|
||||||
quantization: Optional[QuantizationMethods] = None,
|
quantization: Optional[QuantizationMethods] = None,
|
||||||
@ -264,6 +267,7 @@ class LLM:
|
|||||||
skip_tokenizer_init=skip_tokenizer_init,
|
skip_tokenizer_init=skip_tokenizer_init,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
allowed_local_media_path=allowed_local_media_path,
|
allowed_local_media_path=allowed_local_media_path,
|
||||||
|
allowed_media_domains=allowed_media_domains,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
quantization=quantization,
|
quantization=quantization,
|
||||||
|
|||||||
@ -3,12 +3,14 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import multiprocessing.forkserver as forkserver
|
import multiprocessing.forkserver as forkserver
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
|||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware:
|
||||||
"""
|
"""
|
||||||
Pure ASGI middleware that authenticates each request by checking
|
Pure ASGI middleware that authenticates each request by checking
|
||||||
if the Authorization header exists and equals "Bearer {api_key}".
|
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
|
|||||||
|
|
||||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||||
self.app = app
|
self.app = app
|
||||||
self.api_tokens = {f"Bearer {token}" for token in tokens}
|
self.api_tokens = [
|
||||||
|
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
def verify_token(self, headers: Headers) -> bool:
|
||||||
|
authorization_header_value = headers.get("Authorization")
|
||||||
|
if not authorization_header_value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
scheme, _, param = authorization_header_value.partition(" ")
|
||||||
|
if scheme.lower() != "bearer":
|
||||||
|
return False
|
||||||
|
|
||||||
|
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||||
|
|
||||||
|
token_match = False
|
||||||
|
for token_hash in self.api_tokens:
|
||||||
|
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||||
|
|
||||||
|
return token_match
|
||||||
|
|
||||||
def __call__(self, scope: Scope, receive: Receive,
|
def __call__(self, scope: Scope, receive: Receive,
|
||||||
send: Send) -> Awaitable[None]:
|
send: Send) -> Awaitable[None]:
|
||||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
|
|||||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||||
headers = Headers(scope=scope)
|
headers = Headers(scope=scope)
|
||||||
# Type narrow to satisfy mypy.
|
# Type narrow to satisfy mypy.
|
||||||
if url_path.startswith("/v1") and headers.get(
|
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||||
"Authorization") not in self.api_tokens:
|
|
||||||
response = JSONResponse(content={"error": "Unauthorized"},
|
response = JSONResponse(content={"error": "Unauthorized"},
|
||||||
status_code=401)
|
status_code=401)
|
||||||
return response(scope, receive, send)
|
return response(scope, receive, send)
|
||||||
@ -1696,6 +1716,7 @@ async def init_app_state(
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
|
trust_request_chat_template=args.trust_request_chat_template,
|
||||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
enable_auto_tools=args.enable_auto_tool_choice,
|
enable_auto_tools=args.enable_auto_tool_choice,
|
||||||
exclude_tools_when_tool_choice_none=args.
|
exclude_tools_when_tool_choice_none=args.
|
||||||
|
|||||||
@ -103,9 +103,13 @@ class FrontendArgs:
|
|||||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||||
"""The format to render message content within a chat template.
|
"""The format to render message content within a chat template.
|
||||||
|
|
||||||
* "string" will render the content as a string. Example: `"Hello World"`
|
* "string" will render the content as a string. Example: `"Hello World"`
|
||||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
|
* "openai" will render the content as a list of dictionaries, similar to
|
||||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||||
|
trust_request_chat_template: bool = False
|
||||||
|
"""Whether to trust the chat template provided in the request. If False,
|
||||||
|
the server will always use the chat template specified by `--chat-template`
|
||||||
|
or the ones from tokenizer."""
|
||||||
response_role: str = "assistant"
|
response_role: str = "assistant"
|
||||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||||
ssl_keyfile: Optional[str] = None
|
ssl_keyfile: Optional[str] = None
|
||||||
|
|||||||
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request_logger: Optional[RequestLogger],
|
request_logger: Optional[RequestLogger],
|
||||||
chat_template: Optional[str],
|
chat_template: Optional[str],
|
||||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||||
|
trust_request_chat_template: bool = False,
|
||||||
return_tokens_as_token_ids: bool = False,
|
return_tokens_as_token_ids: bool = False,
|
||||||
reasoning_parser: str = "",
|
reasoning_parser: str = "",
|
||||||
enable_auto_tools: bool = False,
|
enable_auto_tools: bool = False,
|
||||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
self.response_role = response_role
|
self.response_role = response_role
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_content_format: Final = chat_template_content_format
|
self.chat_template_content_format: Final = chat_template_content_format
|
||||||
|
self.trust_request_chat_template = trust_request_chat_template
|
||||||
self.enable_log_outputs = enable_log_outputs
|
self.enable_log_outputs = enable_log_outputs
|
||||||
|
|
||||||
# set up tool use
|
# set up tool use
|
||||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if not self.use_harmony:
|
if not self.use_harmony:
|
||||||
# Common case.
|
# Common case.
|
||||||
|
request_chat_template = request.chat_template
|
||||||
|
chat_template_kwargs = request.chat_template_kwargs
|
||||||
|
if not self.trust_request_chat_template and (
|
||||||
|
request_chat_template is not None or
|
||||||
|
(chat_template_kwargs and
|
||||||
|
chat_template_kwargs.get("chat_template") is not None)):
|
||||||
|
return self.create_error_response(
|
||||||
|
"Chat template is passed with request, but "
|
||||||
|
"--trust-request-chat-template is not set. "
|
||||||
|
"Refused request with untrusted chat template.")
|
||||||
(
|
(
|
||||||
conversation,
|
conversation,
|
||||||
request_prompts,
|
request_prompts,
|
||||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request,
|
request,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
request.messages,
|
request.messages,
|
||||||
chat_template=request.chat_template or self.chat_template,
|
chat_template=request_chat_template or self.chat_template,
|
||||||
chat_template_content_format=self.
|
chat_template_content_format=self.
|
||||||
chat_template_content_format,
|
chat_template_content_format,
|
||||||
add_generation_prompt=request.add_generation_prompt,
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
|||||||
@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
|
|||||||
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||||
|
|
||||||
|
|
||||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||||
|
sequence_parallel_size: int) -> list[int]:
|
||||||
|
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
|
||||||
|
sequence_parallel_size)
|
||||||
|
|
||||||
|
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
|
||||||
|
return sp_tokens.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||||
|
sequence_parallel_size: int,
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
chunk_idx: int) -> list[int]:
|
chunk_idx: int) -> list[int]:
|
||||||
dp_size = len(num_tokens_across_dp_cpu)
|
|
||||||
|
|
||||||
local_size = [-1] * dp_size
|
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
|
||||||
for i in range(dp_size):
|
sequence_parallel_size)
|
||||||
dp_tokens = num_tokens_across_dp_cpu[i]
|
sp_size = len(sp_tokens)
|
||||||
|
|
||||||
|
local_size = [-1] * sp_size
|
||||||
|
for i in range(sp_size):
|
||||||
|
# Take into account sharding if MoE activation is sequence parallel.
|
||||||
local_size[i] = min(max_num_tokens,
|
local_size[i] = min(max_num_tokens,
|
||||||
dp_tokens - (max_num_tokens * chunk_idx))
|
sp_tokens[i] - (max_num_tokens * chunk_idx))
|
||||||
if local_size[i] <= 0:
|
if local_size[i] <= 0:
|
||||||
local_size[i] = 1 # ensure lockstep even if done
|
local_size[i] = 1 # ensure lockstep even if done
|
||||||
return local_size
|
return local_size
|
||||||
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DPMetadata:
|
class DPMetadata:
|
||||||
max_tokens_across_dp_cpu: torch.Tensor
|
max_tokens_across_dp_cpu: torch.Tensor
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor
|
num_tokens_across_dp_cpu: torch.Tensor
|
||||||
|
|
||||||
|
# NOTE: local_sizes should only be set by the chunked_sizes context manager
|
||||||
local_sizes: Optional[list[int]] = None
|
local_sizes: Optional[list[int]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -98,6 +113,17 @@ class DPMetadata:
|
|||||||
dist.all_reduce(num_tokens_tensor, group=group)
|
dist.all_reduce(num_tokens_tensor, group=group)
|
||||||
return num_tokens_tensor.cpu()
|
return num_tokens_tensor.cpu()
|
||||||
|
|
||||||
|
# Get the cumulative tokens across sequence parallel ranks.
|
||||||
|
# In this case the input to the MoEs will be distributed w.r.t both
|
||||||
|
# DP and TP rank.
|
||||||
|
# When sp_size==1, this is just the cummulative num tokens across DP.
|
||||||
|
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
|
||||||
|
num_tokens_across_sp_cpu = (
|
||||||
|
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
|
||||||
|
num_tokens_across_sp_cpu = (
|
||||||
|
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
|
||||||
|
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def should_ubatch_across_dp(
|
def should_ubatch_across_dp(
|
||||||
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
|
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
|
||||||
@ -147,10 +173,10 @@ class DPMetadata:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make(
|
def make(
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
attn_metadata: Any,
|
attn_metadata: Any,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_tokens_across_dp: Optional[torch.Tensor] = None
|
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
|
||||||
) -> "DPMetadata":
|
) -> "DPMetadata":
|
||||||
|
|
||||||
assert parallel_config.data_parallel_size > 1
|
assert parallel_config.data_parallel_size > 1
|
||||||
@ -167,18 +193,18 @@ class DPMetadata:
|
|||||||
|
|
||||||
# If num_tokens_across_dp is None, it will be computed by all_reduce
|
# If num_tokens_across_dp is None, it will be computed by all_reduce
|
||||||
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
|
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
|
||||||
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
|
assert (num_tokens_across_dp_cpu is None
|
||||||
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
|
or num_tokens_across_dp_cpu[dp_rank] == batchsize
|
||||||
if num_tokens_across_dp is None:
|
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
|
||||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
if num_tokens_across_dp_cpu is None:
|
||||||
|
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
|
||||||
batchsize, dp_size, dp_rank)
|
batchsize, dp_size, dp_rank)
|
||||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
|
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
|
||||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
|
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
|
||||||
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
|
|
||||||
num_tokens_across_dp)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
|
def chunked_sizes(self, sequence_parallel_size: int,
|
||||||
|
max_chunk_size_per_rank: int, chunk_idx: int):
|
||||||
"""
|
"""
|
||||||
Context manager to compute and temporarily set the per-rank local token
|
Context manager to compute and temporarily set the per-rank local token
|
||||||
sizes for a specific chunk during chunked forward execution.
|
sizes for a specific chunk during chunked forward execution.
|
||||||
@ -192,31 +218,40 @@ class DPMetadata:
|
|||||||
`chunk_idx`, this context manager sets `self.local_sizes` to the number
|
`chunk_idx`, this context manager sets `self.local_sizes` to the number
|
||||||
of tokens to process in that chunk on each rank.
|
of tokens to process in that chunk on each rank.
|
||||||
|
|
||||||
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
|
|
||||||
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
|
|
||||||
to determine the chunk-wise split.
|
|
||||||
|
|
||||||
`self.local_sizes` is only valid inside the context.
|
`self.local_sizes` is only valid inside the context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
sequence_parallel_size: When Attn is TP and MoE layers are EP,
|
||||||
|
we use SP between the layers to avoid
|
||||||
|
redundant ops. We need this value to
|
||||||
|
compute the chunked sizes.
|
||||||
max_chunk_size_per_rank: The max number of tokens each rank is
|
max_chunk_size_per_rank: The max number of tokens each rank is
|
||||||
allowed to process in this chunk.
|
allowed to process in this chunk.
|
||||||
chunk_idx: The index of the chunk to compute sizes for.
|
chunk_idx: The index of the chunk to compute sizes for.
|
||||||
"""
|
"""
|
||||||
cu_sizes = self.cu_tokens_across_dp_cpu
|
|
||||||
num_tokens_across_dp_cpu = [
|
|
||||||
(cu_sizes[i] -
|
|
||||||
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
|
|
||||||
for i in range(len(cu_sizes))
|
|
||||||
]
|
|
||||||
self.local_sizes = _compute_chunked_local_num_tokens(
|
self.local_sizes = _compute_chunked_local_num_tokens(
|
||||||
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
|
self.num_tokens_across_dp_cpu, sequence_parallel_size,
|
||||||
|
max_chunk_size_per_rank, chunk_idx)
|
||||||
|
try:
|
||||||
|
yield self.local_sizes
|
||||||
|
finally:
|
||||||
|
self.local_sizes = None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def sp_local_sizes(self, sequence_parallel_size: int):
|
||||||
|
"""
|
||||||
|
Context mamager for setting self.local_sizes. Same as self.chunked_sizes
|
||||||
|
but without any chunking.
|
||||||
|
"""
|
||||||
|
self.local_sizes = _compute_sp_num_tokens(
|
||||||
|
self.num_tokens_across_dp_cpu, sequence_parallel_size)
|
||||||
try:
|
try:
|
||||||
yield self.local_sizes
|
yield self.local_sizes
|
||||||
finally:
|
finally:
|
||||||
self.local_sizes = None
|
self.local_sizes = None
|
||||||
|
|
||||||
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
|
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
|
||||||
|
assert self.local_sizes is not None
|
||||||
return self.local_sizes
|
return self.local_sizes
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
561
vllm/model_executor/layers/batch_invariant.py
Normal file
561
vllm/model_executor/layers/batch_invariant.py
Normal file
@ -0,0 +1,561 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
from collections import namedtuple
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
|
||||||
|
args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
ret = {}
|
||||||
|
m, n, k = args["M"], args["N"], args["K"]
|
||||||
|
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||||
|
if "tiles_per_update" in args:
|
||||||
|
ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||||
|
f"tiles_per_update={args['tiles_per_update']:02}]")
|
||||||
|
if "c_ptr" in args:
|
||||||
|
bytes_per_elem = args["c_ptr"].element_size()
|
||||||
|
else:
|
||||||
|
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
||||||
|
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
||||||
|
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||||
|
group_id = tile_id // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||||
|
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||||
|
return pid_m, pid_n
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(launch_metadata=_matmul_launch_metadata)
|
||||||
|
def matmul_kernel_persistent(
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr, #
|
||||||
|
bias_ptr,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K, #
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
BLOCK_SIZE_M: tl.constexpr, #
|
||||||
|
BLOCK_SIZE_N: tl.constexpr, #
|
||||||
|
BLOCK_SIZE_K: tl.constexpr, #
|
||||||
|
GROUP_SIZE_M: tl.constexpr, #
|
||||||
|
NUM_SMS: tl.constexpr, #
|
||||||
|
A_LARGE: tl.constexpr,
|
||||||
|
B_LARGE: tl.constexpr,
|
||||||
|
C_LARGE: tl.constexpr,
|
||||||
|
HAS_BIAS: tl.constexpr,
|
||||||
|
):
|
||||||
|
start_pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
|
num_tiles = num_pid_m * num_pid_n
|
||||||
|
|
||||||
|
tile_id_c = start_pid - NUM_SMS
|
||||||
|
|
||||||
|
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
|
||||||
|
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
|
||||||
|
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
|
||||||
|
GROUP_SIZE_M, NUM_SMS)
|
||||||
|
start_m = pid_m * BLOCK_SIZE_M
|
||||||
|
start_n = pid_n * BLOCK_SIZE_N
|
||||||
|
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
if A_LARGE:
|
||||||
|
offs_am = offs_am.to(tl.int64)
|
||||||
|
if B_LARGE:
|
||||||
|
offs_bn = offs_bn.to(tl.int64)
|
||||||
|
offs_am = tl.where(offs_am < M, offs_am, 0)
|
||||||
|
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
|
||||||
|
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
|
||||||
|
BLOCK_SIZE_M)
|
||||||
|
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
|
||||||
|
BLOCK_SIZE_N)
|
||||||
|
|
||||||
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
for ki in range(k_tiles):
|
||||||
|
if A_LARGE or B_LARGE:
|
||||||
|
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
|
||||||
|
tl.int64)
|
||||||
|
else:
|
||||||
|
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
|
||||||
|
offs_k[None, :] * stride_ak)
|
||||||
|
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||||
|
offs_bn[None, :] * stride_bn)
|
||||||
|
|
||||||
|
a = tl.load(a_ptrs,
|
||||||
|
mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
|
||||||
|
other=0.0)
|
||||||
|
b = tl.load(b_ptrs,
|
||||||
|
mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
|
||||||
|
other=0.0)
|
||||||
|
accumulator = tl.dot(a, b, accumulator)
|
||||||
|
|
||||||
|
tile_id_c += NUM_SMS
|
||||||
|
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
|
||||||
|
GROUP_SIZE_M, NUM_SMS)
|
||||||
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
if C_LARGE:
|
||||||
|
offs_cm = offs_cm.to(tl.int64)
|
||||||
|
offs_cn = offs_cn.to(tl.int64)
|
||||||
|
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
|
||||||
|
None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
if HAS_BIAS:
|
||||||
|
bias_ptrs = bias_ptr + offs_cn
|
||||||
|
bias = tl.load(bias_ptrs, mask=offs_cn < N,
|
||||||
|
other=0.0).to(tl.float32)
|
||||||
|
accumulator += bias
|
||||||
|
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||||
|
c = accumulator.to(tl.float8e4nv)
|
||||||
|
else:
|
||||||
|
c = accumulator.to(tl.float16)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_persistent(a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
bias: Union[torch.Tensor, None] = None):
|
||||||
|
# Check constraints.
|
||||||
|
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||||
|
assert a.dtype == b.dtype, "Incompatible dtypes"
|
||||||
|
assert bias is None or bias.dim() == 1, (
|
||||||
|
"Currently assuming bias is 1D, let Horace know if you run into this")
|
||||||
|
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||||
|
M, K = a.shape
|
||||||
|
K, N = b.shape
|
||||||
|
dtype = a.dtype
|
||||||
|
# Allocates output.
|
||||||
|
c = torch.empty((M, N), device=a.device, dtype=dtype)
|
||||||
|
|
||||||
|
# 1D launch kernel where each block gets its own program.
|
||||||
|
def grid(META):
|
||||||
|
return (min(
|
||||||
|
NUM_SMS,
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||||
|
triton.cdiv(N, META["BLOCK_SIZE_N"])), )
|
||||||
|
|
||||||
|
configs = {
|
||||||
|
torch.bfloat16: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
torch.float16: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 64,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
torch.float32: {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 32,
|
||||||
|
"GROUP_SIZE_M": 8,
|
||||||
|
"num_stages": 3,
|
||||||
|
"num_warps": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# print(a.device, b.device, c.device)
|
||||||
|
matmul_kernel_persistent[grid](
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c, #
|
||||||
|
bias,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K, #
|
||||||
|
a.stride(0),
|
||||||
|
a.stride(1), #
|
||||||
|
b.stride(0),
|
||||||
|
b.stride(1), #
|
||||||
|
c.stride(0),
|
||||||
|
c.stride(1), #
|
||||||
|
NUM_SMS=NUM_SMS, #
|
||||||
|
A_LARGE=a.numel() > 2**31,
|
||||||
|
B_LARGE=b.numel() > 2**31,
|
||||||
|
C_LARGE=c.numel() > 2**31,
|
||||||
|
HAS_BIAS=bias is not None,
|
||||||
|
**configs[dtype],
|
||||||
|
)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _log_softmax_kernel(
|
||||||
|
input_ptr,
|
||||||
|
output_ptr,
|
||||||
|
input_row_stride,
|
||||||
|
output_row_stride,
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute log_softmax along the last dimension of a 2D tensor.
|
||||||
|
Each block handles one row of the input tensor.
|
||||||
|
"""
|
||||||
|
# Get the row index for this block
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
|
||||||
|
# Compute base pointers for input and output rows
|
||||||
|
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||||
|
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||||
|
|
||||||
|
# Step 1: Find maximum value in the row for numerical stability
|
||||||
|
max_val = -float("inf")
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
|
||||||
|
|
||||||
|
# Update maximum
|
||||||
|
max_val = tl.max(tl.maximum(vals, max_val))
|
||||||
|
|
||||||
|
# Step 2: Compute sum of exp(x - max_val)
|
||||||
|
sum_exp = 0.0
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
# Compute exp(x - max_val) and accumulate
|
||||||
|
exp_vals = tl.exp(vals - max_val)
|
||||||
|
sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
|
||||||
|
|
||||||
|
# Compute log(sum_exp)
|
||||||
|
log_sum_exp = tl.log(sum_exp)
|
||||||
|
|
||||||
|
# Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
|
||||||
|
for col_offset in range(0, n_cols, BLOCK_SIZE):
|
||||||
|
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_idx < n_cols
|
||||||
|
|
||||||
|
# Load values
|
||||||
|
vals = tl.load(row_start_ptr + col_idx, mask=mask)
|
||||||
|
|
||||||
|
# Compute log_softmax
|
||||||
|
output = vals - max_val - log_sum_exp
|
||||||
|
|
||||||
|
# Store results
|
||||||
|
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute log_softmax using Triton kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor
|
||||||
|
dim: Dimension along which to compute log_softmax
|
||||||
|
(only -1 or last dim supported)
|
||||||
|
>> Stashed changes
|
||||||
|
Returns:
|
||||||
|
Tensor with log_softmax applied along the specified dimension
|
||||||
|
"""
|
||||||
|
if dim != -1 and dim != input.ndim - 1:
|
||||||
|
raise ValueError("This implementation only supports log_softmax along "
|
||||||
|
"the last dimension")
|
||||||
|
|
||||||
|
# Flatten all dimensions except the last one
|
||||||
|
original_shape = input.shape
|
||||||
|
input_2d = input.reshape(-1, input.shape[-1])
|
||||||
|
input_2d = input_2d.contiguous()
|
||||||
|
|
||||||
|
n_rows, n_cols = input_2d.shape
|
||||||
|
|
||||||
|
# Allocate output tensor
|
||||||
|
output = torch.empty_like(input_2d)
|
||||||
|
|
||||||
|
# Choose block size based on the number of columns
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
|
||||||
|
# Launch kernel with one block per row
|
||||||
|
grid = (n_rows, )
|
||||||
|
_log_softmax_kernel[grid](
|
||||||
|
input_2d,
|
||||||
|
output,
|
||||||
|
input_2d.stride(0),
|
||||||
|
output.stride(0),
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
# Reshape output back to original shape
|
||||||
|
return output.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def mean_kernel(
|
||||||
|
input_ptr,
|
||||||
|
output_ptr,
|
||||||
|
input_stride0,
|
||||||
|
input_stride1,
|
||||||
|
input_stride2,
|
||||||
|
output_stride0,
|
||||||
|
output_stride1,
|
||||||
|
M, # size before reduction dim
|
||||||
|
N, # size of reduction dim
|
||||||
|
K, # size after reduction dim
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Kernel for computing mean along a single dimension.
|
||||||
|
Input is viewed as (M, N, K) where N is the dimension being reduced.
|
||||||
|
"""
|
||||||
|
# Program ID gives us which output element we're computing
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
# Compute output indices
|
||||||
|
m_idx = pid // K
|
||||||
|
k_idx = pid % K
|
||||||
|
|
||||||
|
# Bounds check
|
||||||
|
if m_idx >= M or k_idx >= K:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Accumulate sum across reduction dimension
|
||||||
|
acc = 0.0
|
||||||
|
for n_start in range(0, N, BLOCK_SIZE):
|
||||||
|
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = n_offsets < N
|
||||||
|
|
||||||
|
# Calculate input indices
|
||||||
|
input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \
|
||||||
|
+ k_idx * input_stride2
|
||||||
|
|
||||||
|
# Load and accumulate
|
||||||
|
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
||||||
|
acc += tl.sum(vals)
|
||||||
|
|
||||||
|
# Compute mean and store
|
||||||
|
mean_val = acc / N
|
||||||
|
output_idx = m_idx * output_stride0 + k_idx * output_stride1
|
||||||
|
tl.store(output_ptr + output_idx, mean_val)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_dim(input: torch.Tensor,
|
||||||
|
dim: int,
|
||||||
|
keepdim: bool = False,
|
||||||
|
dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Triton implementation of torch.mean with single dimension reduction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor
|
||||||
|
dim: Single dimension along which to compute mean
|
||||||
|
keepdim: Whether to keep the reduced dimension
|
||||||
|
dtype: Output dtype. If None, uses input dtype
|
||||||
|
(or float32 for integer inputs)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with mean values along specified dimension
|
||||||
|
"""
|
||||||
|
# Validate inputs
|
||||||
|
assert input.is_cuda, "Input must be a CUDA tensor"
|
||||||
|
assert -input.ndim <= dim < input.ndim, (
|
||||||
|
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")
|
||||||
|
|
||||||
|
# Handle negative dim
|
||||||
|
if dim < 0:
|
||||||
|
dim = dim + input.ndim
|
||||||
|
|
||||||
|
# Handle dtype
|
||||||
|
if dtype is None:
|
||||||
|
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = input.dtype
|
||||||
|
|
||||||
|
# Convert input to appropriate dtype if needed
|
||||||
|
if input.dtype != dtype:
|
||||||
|
input = input.to(dtype)
|
||||||
|
|
||||||
|
# Get input shape and strides
|
||||||
|
shape = list(input.shape)
|
||||||
|
|
||||||
|
# Calculate dimensions for kernel
|
||||||
|
M = 1
|
||||||
|
for i in range(dim):
|
||||||
|
M *= shape[i]
|
||||||
|
|
||||||
|
N = shape[dim]
|
||||||
|
|
||||||
|
K = 1
|
||||||
|
for i in range(dim + 1, len(shape)):
|
||||||
|
K *= shape[i]
|
||||||
|
|
||||||
|
# Reshape input to 3D view (M, N, K)
|
||||||
|
input_3d = input.reshape(M, N, K)
|
||||||
|
|
||||||
|
# Create output shape
|
||||||
|
if keepdim:
|
||||||
|
output_shape = shape.copy()
|
||||||
|
output_shape[dim] = 1
|
||||||
|
else:
|
||||||
|
output_shape = shape[:dim] + shape[dim + 1:]
|
||||||
|
|
||||||
|
# Create output tensor
|
||||||
|
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
||||||
|
|
||||||
|
# Reshape output for kernel
|
||||||
|
if keepdim:
|
||||||
|
output_2d = output.reshape(M, 1, K).squeeze(1)
|
||||||
|
else:
|
||||||
|
output_2d = output.reshape(M, K)
|
||||||
|
|
||||||
|
# Launch kernel
|
||||||
|
grid = (M * K, )
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
|
||||||
|
mean_kernel[grid](
|
||||||
|
input_3d,
|
||||||
|
output_2d,
|
||||||
|
input_3d.stride(0),
|
||||||
|
input_3d.stride(1),
|
||||||
|
input_3d.stride(2),
|
||||||
|
output_2d.stride(0),
|
||||||
|
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def mm_batch_invariant(a, b):
|
||||||
|
return matmul_persistent(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def addmm_batch_invariant(bias, a, b):
|
||||||
|
return matmul_persistent(a, b, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||||
|
assert not _half_to_float, "not implemented"
|
||||||
|
return log_softmax(input, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def mean_batch_invariant(input,
|
||||||
|
dim,
|
||||||
|
keepdim=False,
|
||||||
|
dtype: Union[torch.dtype, None] = None):
|
||||||
|
assert dtype is None or dtype == torch.float32, \
|
||||||
|
f"unsupported dtype: {dtype}"
|
||||||
|
|
||||||
|
result = input.to(torch.float32)
|
||||||
|
|
||||||
|
# Sort dimensions to reduce from largest to smallest to handle shifting dims
|
||||||
|
# during iterative reduction.
|
||||||
|
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
|
||||||
|
|
||||||
|
# Iteratively apply a deterministic mean.
|
||||||
|
for d in sorted_dims:
|
||||||
|
result = mean_dim(result, dim=d, keepdim=True)
|
||||||
|
|
||||||
|
if not keepdim:
|
||||||
|
# Squeeze the reduced dimensions.
|
||||||
|
for d in sorted_dims:
|
||||||
|
result = result.squeeze(d)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_batch_invariant_MODE = False
|
||||||
|
_batch_invariant_LIB = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_batch_invariant_mode_enabled():
|
||||||
|
return _batch_invariant_MODE
|
||||||
|
|
||||||
|
|
||||||
|
def enable_batch_invariant_mode():
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
if _batch_invariant_MODE:
|
||||||
|
return
|
||||||
|
|
||||||
|
_batch_invariant_MODE = True
|
||||||
|
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||||
|
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||||
|
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||||
|
_batch_invariant_LIB.impl("aten::_log_softmax",
|
||||||
|
_log_softmax_batch_invariant, "CUDA")
|
||||||
|
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||||
|
|
||||||
|
|
||||||
|
def disable_batch_invariant_mode():
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
if _batch_invariant_LIB is not None:
|
||||||
|
_batch_invariant_LIB._destroy()
|
||||||
|
_batch_invariant_MODE = False
|
||||||
|
_batch_invariant_LIB = None
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def set_batch_invariant_mode(enabled: bool = True):
|
||||||
|
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||||
|
old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
|
||||||
|
if enabled:
|
||||||
|
enable_batch_invariant_mode()
|
||||||
|
else:
|
||||||
|
disable_batch_invariant_mode()
|
||||||
|
yield
|
||||||
|
if _batch_invariant_LIB is not None:
|
||||||
|
_batch_invariant_LIB._destroy()
|
||||||
|
_batch_invariant_MODE, _batch_invariant_LIB = old_data
|
||||||
|
|
||||||
|
|
||||||
|
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||||
|
return AttentionBlockSize(block_m=16, block_n=16)
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_kernel_override_batch_invariant():
|
||||||
|
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
|
||||||
|
is_overridden = False
|
||||||
|
val = os.getenv(env_key, "0")
|
||||||
|
try:
|
||||||
|
is_overridden = int(val) != 0
|
||||||
|
except ValueError:
|
||||||
|
is_overridden = False
|
||||||
|
return is_overridden
|
||||||
|
|
||||||
|
|
||||||
|
def init_batch_invariance():
|
||||||
|
# this will hit all the csrc overrides as well
|
||||||
|
if vllm_kernel_override_batch_invariant():
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
|
||||||
|
enable_batch_invariant_mode()
|
||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Literal, Optional, Union, get_args, overload
|
from typing import Callable, Literal, Optional, Union, get_args, overload
|
||||||
|
|
||||||
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
|
|||||||
if dp_size is not None else get_dp_group().world_size)
|
if dp_size is not None else get_dp_group().world_size)
|
||||||
|
|
||||||
self.is_sequence_parallel = is_sequence_parallel
|
self.is_sequence_parallel = is_sequence_parallel
|
||||||
if self.is_sequence_parallel:
|
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||||
self.sp_size = tp_size_
|
|
||||||
|
|
||||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||||
FusedMoEParallelConfig.make(
|
FusedMoEParallelConfig.make(
|
||||||
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
|
|||||||
# clamp start and end
|
# clamp start and end
|
||||||
chunk_start = min(chunk_start, num_tokens - 1)
|
chunk_start = min(chunk_start, num_tokens - 1)
|
||||||
chunk_end = min(chunk_end, num_tokens)
|
chunk_end = min(chunk_end, num_tokens)
|
||||||
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
|
with ctx.dp_metadata.chunked_sizes(self.sp_size,
|
||||||
|
moe_dp_chunk_size_per_rank,
|
||||||
chunk_idx):
|
chunk_idx):
|
||||||
process_chunk(chunk_start,
|
process_chunk(chunk_start,
|
||||||
chunk_end,
|
chunk_end,
|
||||||
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
|
|||||||
else:
|
else:
|
||||||
shared_output = None
|
shared_output = None
|
||||||
|
|
||||||
if do_naive_dispatch_combine:
|
ctx = get_forward_context()
|
||||||
hidden_states, router_logits = get_ep_group().dispatch(
|
sp_ctx = ctx.dp_metadata.sp_local_sizes(
|
||||||
hidden_states, router_logits)
|
self.sp_size) if ctx.dp_metadata else nullcontext()
|
||||||
|
|
||||||
# Matrix multiply.
|
with sp_ctx:
|
||||||
final_hidden_states = self.quant_method.apply(
|
if do_naive_dispatch_combine:
|
||||||
layer=self,
|
hidden_states, router_logits = get_ep_group().dispatch(
|
||||||
x=hidden_states,
|
hidden_states, router_logits, self.is_sequence_parallel)
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=self.top_k,
|
|
||||||
renormalize=self.renormalize,
|
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
|
||||||
global_num_experts=self.global_num_experts,
|
|
||||||
expert_map=self.expert_map,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
num_expert_group=self.num_expert_group,
|
|
||||||
custom_routing_function=self.custom_routing_function,
|
|
||||||
scoring_func=self.scoring_func,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
e_score_correction_bias=self.e_score_correction_bias,
|
|
||||||
activation=self.activation,
|
|
||||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
|
||||||
enable_eplb=self.enable_eplb,
|
|
||||||
expert_load_view=self.expert_load_view,
|
|
||||||
logical_to_physical_map=self.logical_to_physical_map,
|
|
||||||
logical_replica_count=self.logical_replica_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
if shared_output is not None:
|
# Matrix multiply.
|
||||||
assert not isinstance(final_hidden_states, tuple)
|
final_hidden_states = self.quant_method.apply(
|
||||||
assert self.shared_experts is not None
|
layer=self,
|
||||||
final_hidden_states = (
|
x=hidden_states,
|
||||||
shared_output,
|
router_logits=router_logits,
|
||||||
final_hidden_states,
|
top_k=self.top_k,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
global_num_experts=self.global_num_experts,
|
||||||
|
expert_map=self.expert_map,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
custom_routing_function=self.custom_routing_function,
|
||||||
|
scoring_func=self.scoring_func,
|
||||||
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
|
activation=self.activation,
|
||||||
|
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
|
enable_eplb=self.enable_eplb,
|
||||||
|
expert_load_view=self.expert_load_view,
|
||||||
|
logical_to_physical_map=self.logical_to_physical_map,
|
||||||
|
logical_replica_count=self.logical_replica_count,
|
||||||
)
|
)
|
||||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
|
||||||
assert isinstance(final_hidden_states, tuple)
|
|
||||||
final_hidden_states, zero_expert_result = final_hidden_states
|
|
||||||
|
|
||||||
def reduce_output(states: torch.Tensor,
|
if shared_output is not None:
|
||||||
do_combine: bool = True) -> torch.Tensor:
|
assert not isinstance(final_hidden_states, tuple)
|
||||||
if do_naive_dispatch_combine and do_combine:
|
assert self.shared_experts is not None
|
||||||
states = get_ep_group().combine(states)
|
final_hidden_states = (
|
||||||
|
shared_output,
|
||||||
|
final_hidden_states,
|
||||||
|
)
|
||||||
|
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||||
|
assert isinstance(final_hidden_states, tuple)
|
||||||
|
final_hidden_states, zero_expert_result = final_hidden_states
|
||||||
|
|
||||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
def reduce_output(states: torch.Tensor,
|
||||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
do_combine: bool = True) -> torch.Tensor:
|
||||||
|
if do_naive_dispatch_combine and do_combine:
|
||||||
|
states = get_ep_group().combine(states,
|
||||||
|
self.is_sequence_parallel)
|
||||||
|
|
||||||
return states
|
if (not self.is_sequence_parallel and self.reduce_results
|
||||||
|
and (self.tp_size > 1 or self.ep_size > 1)):
|
||||||
|
states = self.maybe_all_reduce_tensor_model_parallel(
|
||||||
|
states)
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
return states
|
||||||
return (
|
|
||||||
reduce_output(final_hidden_states[0], do_combine=False),
|
if self.shared_experts is not None:
|
||||||
reduce_output(final_hidden_states[1]),
|
return (
|
||||||
)
|
reduce_output(final_hidden_states[0], do_combine=False),
|
||||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
reduce_output(final_hidden_states[1]),
|
||||||
assert isinstance(final_hidden_states, torch.Tensor)
|
)
|
||||||
return reduce_output(final_hidden_states) + zero_expert_result
|
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||||
else:
|
assert isinstance(final_hidden_states, torch.Tensor)
|
||||||
return reduce_output(final_hidden_states)
|
return reduce_output(final_hidden_states) + zero_expert_result
|
||||||
|
else:
|
||||||
|
return reduce_output(final_hidden_states)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_expert_params_mapping(
|
def make_expert_params_mapping(
|
||||||
|
|||||||
@ -639,6 +639,19 @@ def runai_safetensors_weights_iterator(
|
|||||||
yield from tensor_iter
|
yield from tensor_iter
|
||||||
|
|
||||||
|
|
||||||
|
def _init_loader(
|
||||||
|
pg: torch.distributed.ProcessGroup,
|
||||||
|
device: torch.device,
|
||||||
|
f_list: list[str],
|
||||||
|
*,
|
||||||
|
nogds: bool = False,
|
||||||
|
):
|
||||||
|
loader = SafeTensorsFileLoader(pg, device, nogds=nogds)
|
||||||
|
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||||
|
loader.add_filenames(rank_file_map)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
def fastsafetensors_weights_iterator(
|
def fastsafetensors_weights_iterator(
|
||||||
hf_weights_files: list[str],
|
hf_weights_files: list[str],
|
||||||
use_tqdm_on_load: bool,
|
use_tqdm_on_load: bool,
|
||||||
@ -656,17 +669,31 @@ def fastsafetensors_weights_iterator(
|
|||||||
for i in range(0, len(hf_weights_files), pg.size())
|
for i in range(0, len(hf_weights_files), pg.size())
|
||||||
]
|
]
|
||||||
|
|
||||||
|
nogds = False
|
||||||
|
|
||||||
for f_list in tqdm(
|
for f_list in tqdm(
|
||||||
weight_files_sub_lists,
|
weight_files_sub_lists,
|
||||||
desc="Loading safetensors using Fastsafetensor loader",
|
desc="Loading safetensors using Fastsafetensor loader",
|
||||||
disable=not enable_tqdm(use_tqdm_on_load),
|
disable=not enable_tqdm(use_tqdm_on_load),
|
||||||
bar_format=_BAR_FORMAT,
|
bar_format=_BAR_FORMAT,
|
||||||
):
|
):
|
||||||
loader = SafeTensorsFileLoader(pg, device)
|
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||||
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
|
||||||
loader.add_filenames(rank_file_map)
|
|
||||||
try:
|
try:
|
||||||
fb = loader.copy_files_to_device()
|
try:
|
||||||
|
fb = loader.copy_files_to_device()
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "gds" not in str(e):
|
||||||
|
raise
|
||||||
|
|
||||||
|
loader.close()
|
||||||
|
nogds = True
|
||||||
|
logger.warning_once(
|
||||||
|
"GDS not enabled, setting `nogds=True`.\n"
|
||||||
|
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
|
||||||
|
)
|
||||||
|
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||||
|
fb = loader.copy_files_to_device()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keys = list(fb.key_to_rank_lidx.keys())
|
keys = list(fb.key_to_rank_lidx.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
|||||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||||
from transformers.models.aria.processing_aria import AriaProcessor
|
from transformers.models.aria.processing_aria import AriaProcessor
|
||||||
|
|
||||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
from vllm.config import QuantizationConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@ -38,8 +38,7 @@ from .idefics2_vision_model import (
|
|||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
|
||||||
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
is_pp_missing_parameter, maybe_prefix,
|
is_pp_missing_parameter, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class AriaImagePixelInputs(TensorSchema):
|
class AriaImagePixelInputs(TensorSchema):
|
||||||
@ -298,14 +297,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
|||||||
Experts (MoE) Layer.
|
Experts (MoE) Layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
self,
|
super().__init__(vllm_config, prefix)
|
||||||
config: AriaTextConfig,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config = vllm_config.quant_config
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__(config, cache_config, quant_config, prefix)
|
|
||||||
self.mlp = AriaTextMoELayer(config,
|
self.mlp = AriaTextMoELayer(config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp")
|
prefix=f"{prefix}.mlp")
|
||||||
@ -605,19 +602,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
multimodal_embeddings = self._process_image_input(image_input)
|
multimodal_embeddings = self._process_image_input(image_input)
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
|
||||||
self.config.image_token_index)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -628,10 +612,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
# always pass the input via `inputs_embeds`
|
inputs_embeds = self.get_input_embeddings(
|
||||||
# to make sure the computation graph is consistent
|
input_ids,
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
multimodal_embeddings,
|
||||||
multimodal_embeddings)
|
is_multimodal=input_ids == self.config.image_token_index,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model(
|
hidden_states = self.language_model(
|
||||||
|
|||||||
@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .siglip import SiglipVisionModel
|
from .siglip import SiglipVisionModel
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class AyaVisionImagePixelInputs(TensorSchema):
|
class AyaVisionImagePixelInputs(TensorSchema):
|
||||||
@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
return self._process_image_input(image_input, **kwargs)
|
return self._process_image_input(image_input, **kwargs)
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
multimodal_embeddings=multimodal_embeddings,
|
|
||||||
placeholder_token_id=self.config.image_token_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == self.config.image_token_index,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model.model(
|
hidden_states = self.language_model.model(
|
||||||
|
|||||||
@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||||
prefix=f"{prefix}.encoder")
|
prefix=f"{prefix}.encoder")
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
|||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
self.pooler = self._build_pooler(pooler_config)
|
self.pooler = self._build_pooler(pooler_config)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
|||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.bert.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
loaded_params = loader.load_weights(weights)
|
loaded_params = loader.load_weights(weights)
|
||||||
@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
|
|||||||
Pooler.for_encode(pooler_config),
|
Pooler.for_encode(pooler_config),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.bert.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
loaded_params = loader.load_weights(weights)
|
loaded_params = loader.load_weights(weights)
|
||||||
|
|||||||
@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
|
|||||||
prefix=f"{prefix}.encoder")
|
prefix=f"{prefix}.encoder")
|
||||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
loaded_params = loader.load_weights(weights)
|
loaded_params = loader.load_weights(weights)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.new.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from .blip import BlipVisionModel
|
|||||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||||
SupportsQuant)
|
SupportsQuant)
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix)
|
||||||
|
|
||||||
# We use this internally as placeholders since there is no image token
|
# We use this internally as placeholders since there is no image token
|
||||||
# defined on the HuggingFace repo
|
# defined on the HuggingFace repo
|
||||||
@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
|
||||||
_IMAGE_TOKEN_ID)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model.model(input_ids,
|
hidden_states = self.language_model.model(input_ids,
|
||||||
|
|||||||
@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
|||||||
SupportsQuant)
|
SupportsQuant)
|
||||||
from .utils import (flatten_bn, is_pp_missing_parameter,
|
from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
|
||||||
self.model.vocabulary_mapping.image_token_id)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
image_token_id = self.model.vocabulary_mapping.image_token_id
|
||||||
vision_embeddings)
|
inputs_embeds = self.get_input_embeddings(
|
||||||
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == image_token_id,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.model(input_ids,
|
hidden_states = self.model(input_ids,
|
||||||
|
|||||||
@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.transformer.make_empty_intermediate_tensors)
|
self.transformer.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.transformer.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .siglip import SiglipVisionModel
|
from .siglip import SiglipVisionModel
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class Cohere2VisionImagePixelInputs(TensorSchema):
|
class Cohere2VisionImagePixelInputs(TensorSchema):
|
||||||
@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
return self._process_image_input(image_input, **kwargs)
|
return self._process_image_input(image_input, **kwargs)
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
multimodal_embeddings=multimodal_embeddings,
|
|
||||||
placeholder_token_id=self.config.image_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == self.config.image_token_id,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model.model(
|
hidden_states = self.language_model.model(
|
||||||
|
|||||||
@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self.norm = RMSNorm(self.config.hidden_size,
|
self.norm = RMSNorm(self.config.hidden_size,
|
||||||
eps=self.config.rms_norm_eps)
|
eps=self.config.rms_norm_eps)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
|||||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||||
scale=logit_scale)
|
scale=logit_scale)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
|||||||
prefix=maybe_prefix(
|
prefix=maybe_prefix(
|
||||||
prefix, "model"))
|
prefix, "model"))
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -32,7 +32,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||||
@ -56,8 +55,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
|
||||||
|
|
||||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||||
@ -108,43 +107,6 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Chunk x along the num_tokens axis for sequence parallelism
|
|
||||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
|
||||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
|
||||||
# even though we explicitly pad to avoid this.
|
|
||||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
# all_gather needs the sequence length to be divisible by tp_size
|
|
||||||
seq_len = x.size(0)
|
|
||||||
remainder = seq_len % tp_size
|
|
||||||
if remainder != 0:
|
|
||||||
pad_len = tp_size - remainder
|
|
||||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
|
||||||
|
|
||||||
chunk = x.shape[0] // tp_size
|
|
||||||
start = tp_rank * chunk
|
|
||||||
return torch.narrow(x, 0, start, chunk)
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
seq_len = cdiv(x.size(0), tp_size)
|
|
||||||
shape = list(x.shape)
|
|
||||||
shape[0] = seq_len
|
|
||||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="sequence_parallel_chunk",
|
|
||||||
op_func=sequence_parallel_chunk,
|
|
||||||
fake_impl=sequence_parallel_chunk_fake,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MoE(nn.Module):
|
class DeepseekV2MoE(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -166,20 +128,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self.n_routed_experts: int = config.n_routed_experts
|
self.n_routed_experts: int = config.n_routed_experts
|
||||||
self.n_shared_experts: int = config.n_shared_experts
|
self.n_shared_experts: int = config.n_shared_experts
|
||||||
|
|
||||||
# The all_reduce at the end of attention (during o_proj) means that
|
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||||
# inputs are replicated across each rank of the tensor parallel group.
|
|
||||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
|
||||||
# tokens results in useless duplicate computation and communication.
|
|
||||||
#
|
|
||||||
# In this case, ensure the input to the experts is sequence parallel
|
|
||||||
# to avoid the excess work.
|
|
||||||
#
|
|
||||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
|
||||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
|
||||||
in ("deepep_high_throughput",
|
|
||||||
"deepep_low_latency")
|
|
||||||
and parallel_config.enable_expert_parallel
|
|
||||||
and self.tp_size > 1)
|
|
||||||
|
|
||||||
if config.hidden_act != "silu":
|
if config.hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||||
@ -278,8 +227,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
# TODO: We can replace the all_reduce at the end of attn with a
|
# TODO: We can replace the all_reduce at the end of attn with a
|
||||||
# reduce_scatter instead of chunking here.
|
# reduce_scatter instead of chunking here.
|
||||||
if self.is_sequence_parallel:
|
if self.is_sequence_parallel:
|
||||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|||||||
@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
# The image token id may be various
|
# The image token id may be various
|
||||||
_IMAGE_TOKEN = "<image>"
|
_IMAGE_TOKEN = "<image>"
|
||||||
@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
tokenizer = cached_tokenizer_from_config(model_config)
|
tokenizer = cached_tokenizer_from_config(model_config)
|
||||||
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
|
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
|
||||||
|
|
||||||
self.vision = self._init_vision_module(self.vision_config,
|
self.vision = self._init_vision_module(self.vision_config,
|
||||||
quant_config,
|
quant_config,
|
||||||
@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
|
||||||
self.image_token_id)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# condition is for v0 compatibility
|
# condition is for v0 compatibility
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == self.image_token_id,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model(input_ids,
|
hidden_states = self.language_model(input_ids,
|
||||||
|
|||||||
@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
|
|||||||
Qwen2VLProcessingInfo)
|
Qwen2VLProcessingInfo)
|
||||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
init_vllm_registered_model,
|
init_vllm_registered_model,
|
||||||
maybe_prefix,
|
maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalDataDict
|
from vllm.multimodal.inputs import MultiModalDataDict
|
||||||
@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
return self.language_model
|
return self.language_model
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(self,
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
**kwargs: object) -> MultiModalEmbeddings:
|
||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
if image_input is None:
|
if image_input is None:
|
||||||
return []
|
return []
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
self.config.image_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
|
elif inputs_embeds is None:
|
||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
if image_input is None:
|
inputs_embeds = self.get_input_embeddings(
|
||||||
inputs_embeds = None
|
input_ids,
|
||||||
else:
|
vision_embeddings,
|
||||||
assert input_ids is not None
|
is_multimodal=input_ids == self.config.image_token_id,
|
||||||
inputs_embeds = self.get_multimodal_embeddings(
|
)
|
||||||
input_ids,
|
input_ids = None
|
||||||
image_input=image_input,
|
|
||||||
)
|
|
||||||
input_ids = None
|
|
||||||
|
|
||||||
hidden_states = self.language_model(
|
hidden_states = self.language_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|||||||
@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
|
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
merge_multimodal_embeddings)
|
|
||||||
from .vision import get_vit_attn_backend
|
from .vision import get_vit_attn_backend
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[torch.Tensor] = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if multimodal_embeddings is not None and len(
|
||||||
|
multimodal_embeddings) > 0:
|
||||||
|
self._set_visual_token_mask(input_ids)
|
||||||
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
# This is to satisfy the type checker for each overload
|
||||||
|
if multimodal_embeddings is None or is_multimodal is None:
|
||||||
|
return super().get_input_embeddings(input_ids)
|
||||||
|
|
||||||
if multimodal_embeddings is None:
|
return super().get_input_embeddings(
|
||||||
return inputs_embeds
|
input_ids,
|
||||||
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
self._set_visual_token_mask(input_ids)
|
is_multimodal=is_multimodal,
|
||||||
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
multimodal_embeddings,
|
)
|
||||||
[self.config.im_patch_id])
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -29,10 +29,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
model_config: ModelConfig,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
self.mtp_emb_norm = RMSNorm(config.hidden_size,
|
self.mtp_emb_norm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
|||||||
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
|
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=False)
|
bias=False)
|
||||||
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
|
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
|
||||||
prefix)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
|
|||||||
self.layers = torch.nn.ModuleDict({
|
self.layers = torch.nn.ModuleDict({
|
||||||
str(idx):
|
str(idx):
|
||||||
ErnieMultiTokenPredictorLayer(
|
ErnieMultiTokenPredictorLayer(
|
||||||
config,
|
vllm_config,
|
||||||
f"{prefix}.layers.{idx}",
|
f"{prefix}.layers.{idx}",
|
||||||
model_config=vllm_config.model_config,
|
|
||||||
cache_config=vllm_config.cache_config,
|
|
||||||
)
|
)
|
||||||
for idx in range(self.mtp_start_layer_idx,
|
for idx in range(self.mtp_start_layer_idx,
|
||||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||||
@ -116,6 +110,9 @@ class ErnieMultiTokenPredictor(nn.Module):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -160,6 +157,9 @@ class ErnieMTP(nn.Module, SupportsPP):
|
|||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head.weight = self.model.embed_tokens.weight
|
self.lm_head.weight = self.model.embed_tokens.weight
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
|
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
# Cannot find the following 2 numbers from hf config.
|
# Cannot find the following 2 numbers from hf config.
|
||||||
_IMAGE_TOKEN_ID = 71011
|
_IMAGE_TOKEN_ID = 71011
|
||||||
@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
return self._process_image_input(image_input)
|
return self._process_image_input(image_input)
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
_IMAGE_TOKEN_ID,
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model(
|
hidden_states = self.language_model(
|
||||||
|
|||||||
@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
|||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .siglip import SiglipVisionModel
|
from .siglip import SiglipVisionModel
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
|
|
||||||
return self._process_image_input(image_input)
|
return self._process_image_input(image_input)
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
self.config.image_token_index,
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
|
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == self.config.image_token_index,
|
||||||
|
)
|
||||||
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
|
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
|
||||||
kwargs = self.prepare_attn_masks(
|
kwargs = self.prepare_attn_masks(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|||||||
@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[torch.Tensor] = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||||
# them here, as the model forward has only access to the input_embeds.
|
# them here, as the model forward has only access to the input_embeds.
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
||||||
per_layer_inputs)
|
per_layer_inputs)
|
||||||
|
|
||||||
if multimodal_embeddings is not None \
|
# This is to satisfy the type checker for each overload
|
||||||
and len(multimodal_embeddings) != 0:
|
if multimodal_embeddings is None or is_multimodal is None:
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
return super().get_input_embeddings(input_ids)
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
return super().get_input_embeddings(
|
||||||
multimodal_embeddings,
|
input_ids,
|
||||||
# NOTE: this order of processing mm items is important
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
[self.config.image_token_id, self.config.audio_token_id])
|
is_multimodal=is_multimodal,
|
||||||
return inputs_embeds
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
|
|||||||
|
|
||||||
class Glm4DecoderLayer(nn.Module):
|
class Glm4DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
vllm_config: VllmConfig,
|
||||||
config: Glm4Config,
|
prefix: str = "",
|
||||||
cache_config: Optional[CacheConfig] = None,
|
config: Optional[Glm4Config] = None) -> None:
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
config = config or vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
|||||||
@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
multimodal_embeddings += video_embeddings
|
multimodal_embeddings += video_embeddings
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if (multimodal_embeddings is not None
|
|
||||||
and len(multimodal_embeddings) != 0
|
|
||||||
and all(embed.numel() > 0 for embed in multimodal_embeddings)):
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
[self.config.image_token_id, self.config.video_token_id],
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def get_input_embeddings_v0(
|
def get_input_embeddings_v0(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
|
|||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
|
|||||||
prefix=maybe_prefix(
|
prefix=maybe_prefix(
|
||||||
prefix, "model"))
|
prefix, "model"))
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
from .utils import flatten_bn, isin_list
|
||||||
|
|
||||||
|
|
||||||
class GLMVImagePixelInputs(TensorSchema):
|
class GLMVImagePixelInputs(TensorSchema):
|
||||||
@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
|||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
return vision_embeddings
|
return vision_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids=input_ids,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
multimodal_embeddings=multimodal_embeddings,
|
|
||||||
placeholder_token_id=[
|
|
||||||
self.config.boi_token_id,
|
|
||||||
self.config.pad_token_id,
|
|
||||||
self.config.eoi_token_id,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=isin_list(input_ids, [
|
||||||
|
self.config.boi_token_id,
|
||||||
|
self.config.pad_token_id,
|
||||||
|
self.config.eoi_token_id,
|
||||||
|
]),
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.transformer(input_ids, positions,
|
hidden_states = self.transformer(input_ids, positions,
|
||||||
|
|||||||
@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
|
|||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GptOssConfig,
|
vllm_config: VllmConfig,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
quant_config: QuantizationConfig,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
|
||||||
|
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||||
|
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
self.experts_per_token = config.num_experts_per_tok
|
self.experts_per_token = config.num_experts_per_tok
|
||||||
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
|
|||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
apply_router_weight_on_input=False,
|
apply_router_weight_on_input=False,
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation="swigluoai")
|
activation="swigluoai",
|
||||||
|
is_sequence_parallel=self.is_sequence_parallel)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
if self.is_sequence_parallel:
|
||||||
|
x = sequence_parallel_chunk(x)
|
||||||
|
|
||||||
g = self.router(x)
|
g = self.router(x)
|
||||||
x = self.experts(hidden_states=x, router_logits=g)
|
x = self.experts(hidden_states=x, router_logits=g)
|
||||||
|
|
||||||
|
if self.is_sequence_parallel:
|
||||||
|
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||||
|
x = x[:num_tokens]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GptOssConfig,
|
vllm_config: VllmConfig,
|
||||||
cache_config: CacheConfig,
|
|
||||||
quant_config: QuantizationConfig,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
|
||||||
self.layer_idx = extract_layer_index(prefix)
|
self.layer_idx = extract_layer_index(prefix)
|
||||||
self.attn = OAIAttention(config,
|
self.attn = OAIAttention(config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
cache_config=cache_config)
|
cache_config=cache_config)
|
||||||
self.mlp = MLPBlock(config,
|
self.mlp = MLPBlock(vllm_config,
|
||||||
self.layer_idx,
|
self.layer_idx,
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.mlp")
|
prefix=f"{prefix}.mlp")
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||||
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = vllm_config.model_config.hf_config
|
self.config = vllm_config.model_config.hf_config
|
||||||
self.cache_config = vllm_config.cache_config
|
|
||||||
self.quant_config = vllm_config.quant_config
|
|
||||||
self.parallel_config = vllm_config.parallel_config
|
self.parallel_config = vllm_config.parallel_config
|
||||||
self.config.hidden_size = self.config.hidden_size
|
self.config.hidden_size = self.config.hidden_size
|
||||||
self.embedding = VocabParallelEmbedding(
|
self.embedding = VocabParallelEmbedding(
|
||||||
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
self.config.num_hidden_layers,
|
self.config.num_hidden_layers,
|
||||||
lambda prefix: TransformerBlock(
|
lambda prefix: TransformerBlock(
|
||||||
self.config,
|
vllm_config,
|
||||||
cache_config=self.cache_config,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
),
|
),
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
|
|||||||
@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .blip2 import Blip2QFormerModel
|
from .blip2 import Blip2QFormerModel
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import (AutoWeightsLoader, embed_multimodal,
|
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||||
init_vllm_registered_model, maybe_prefix)
|
|
||||||
|
|
||||||
|
|
||||||
### Audio Input
|
### Audio Input
|
||||||
@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration(
|
|||||||
# Split variable length features into a tuple
|
# Split variable length features into a tuple
|
||||||
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
|
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
|
||||||
|
|
||||||
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
|
return self.language_model
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
self,
|
self,
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration(
|
|||||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
if audio_input is None:
|
if audio_input is None:
|
||||||
return []
|
return []
|
||||||
return None
|
|
||||||
audio_features = self._process_audio_input(audio_input)
|
audio_features = self._process_audio_input(audio_input)
|
||||||
return audio_features
|
return audio_features
|
||||||
|
|
||||||
@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration(
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[torch.Tensor] = None,
|
||||||
|
# Multi-modal token ID may exceed vocab size
|
||||||
|
handle_oov_mm_token: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute the merged LLM / audio embeddings."""
|
# This is to satisfy the type checker for each overload
|
||||||
if multimodal_embeddings is None \
|
if multimodal_embeddings is None or is_multimodal is None:
|
||||||
or len(multimodal_embeddings) == 0:
|
return super().get_input_embeddings(input_ids)
|
||||||
return self.language_model.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
inputs_embeds = embed_multimodal(
|
return super().get_input_embeddings(
|
||||||
input_ids,
|
input_ids,
|
||||||
self.config.audio_token_index,
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
self.language_model.model.get_input_embeddings,
|
is_multimodal=is_multimodal,
|
||||||
multimodal_embeddings,
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration(
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
audio_embeds = self.get_multimodal_embeddings(**kwargs)
|
audio_embeds = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
|
inputs_embeds = self.get_input_embeddings(
|
||||||
|
input_ids,
|
||||||
|
audio_embeds,
|
||||||
|
is_multimodal=input_ids == self.config.audio_token_index,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
model_output = self.language_model(input_ids, positions,
|
model_output = self.language_model(input_ids, positions,
|
||||||
|
|||||||
@ -29,12 +29,13 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.granitemoe import GraniteMoeConfig
|
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import (get_pp_group,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
|
is_sequence_parallel=False,
|
||||||
prefix: str = ""):
|
prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
self.is_sequence_parallel = is_sequence_parallel
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
# Gate always runs at half / full precision for now.
|
||||||
self.gate = ReplicatedLinear(hidden_size,
|
self.gate = ReplicatedLinear(hidden_size,
|
||||||
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
|
|||||||
renormalize=True,
|
renormalize=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
prefix=f"{prefix}.experts")
|
prefix=f"{prefix}.experts",
|
||||||
|
is_sequence_parallel=self.is_sequence_parallel)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
|
|
||||||
|
if self.is_sequence_parallel:
|
||||||
|
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||||
|
|
||||||
|
if self.is_sequence_parallel:
|
||||||
|
final_hidden_states = tensor_model_parallel_all_gather(
|
||||||
|
final_hidden_states, 0)
|
||||||
|
num_tokens = orig_shape[0]
|
||||||
|
final_hidden_states = final_hidden_states[:num_tokens]
|
||||||
|
|
||||||
return final_hidden_states.view(orig_shape)
|
return final_hidden_states.view(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: GraniteMoeConfig,
|
vllm_config: VllmConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
# Requires transformers > 4.32.0
|
# Requires transformers > 4.32.0
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
|||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
|
||||||
prefix=f"{prefix}.block_sparse_moe")
|
prefix=f"{prefix}.block_sparse_moe")
|
||||||
|
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
lora_config = vllm_config.lora_config
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
|
|||||||
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
lambda prefix: GraniteMoeDecoderLayer(
|
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
|
||||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
|
||||||
),
|
|
||||||
prefix=f"{prefix}.layers")
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|||||||
@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
|||||||
moe.n_redundant_experts = self.num_redundant_experts
|
moe.n_redundant_experts = self.num_redundant_experts
|
||||||
moe.experts.update_expert_map()
|
moe.experts.update_expert_map()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from .clip import CLIPVisionModel
|
from .clip import CLIPVisionModel
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .siglip import SiglipVisionModel
|
from .siglip import SiglipVisionModel
|
||||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix)
|
||||||
from .vision import get_vision_encoder_info
|
from .vision import get_vision_encoder_info
|
||||||
|
|
||||||
EOT = "<|endofturn|>"
|
EOT = "<|endofturn|>"
|
||||||
@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
self,
|
self,
|
||||||
**kwargs: Unpack[HCXVisionMultimodalInputs],
|
**kwargs: Unpack[HCXVisionMultimodalInputs],
|
||||||
) -> Optional[MultiModalEmbeddings]:
|
) -> MultiModalEmbeddings:
|
||||||
|
|
||||||
multimodal_embeddings = list()
|
multimodal_embeddings = list()
|
||||||
if kwargs.get("pixel_values_images") is not None:
|
if kwargs.get("pixel_values_images") is not None:
|
||||||
@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
multimodal_embeddings.append(_multimodal_embeddings_videos)
|
multimodal_embeddings.append(_multimodal_embeddings_videos)
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
placeholder_token_id=[
|
|
||||||
self.config.image_token_id,
|
|
||||||
self.config.video_token_id,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
multimodal_embeddings)
|
input_ids,
|
||||||
|
multimodal_embeddings,
|
||||||
|
is_multimodal=isin_list(
|
||||||
|
input_ids,
|
||||||
|
[self.config.image_token_id, self.config.video_token_id]),
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
hidden_states = self.language_model.model(input_ids,
|
hidden_states = self.language_model.model(input_ids,
|
||||||
positions,
|
positions,
|
||||||
|
|||||||
@ -52,8 +52,7 @@ from .idefics2_vision_model import (
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||||
from .llama import LlamaModel
|
from .llama import LlamaModel
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class Idefics3ImagePixelInputs(TensorSchema):
|
class Idefics3ImagePixelInputs(TensorSchema):
|
||||||
@ -539,10 +538,7 @@ class Idefics3Model(nn.Module):
|
|||||||
|
|
||||||
return image_hidden_states
|
return image_hidden_states
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return self.text_model.get_input_embeddings(input_ids)
|
return self.text_model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
return self._process_image_input(image_input)
|
return self._process_image_input(image_input)
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None \
|
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
self.config.image_token_id,
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=input_ids == self.config.image_token_id,
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.model.text_model(input_ids,
|
hidden_states = self.model.text_model(input_ids,
|
||||||
|
|||||||
@ -2,8 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping, MutableSequence
|
from collections.abc import Iterable, Mapping, MutableSequence
|
||||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional,
|
||||||
Union, overload, runtime_checkable)
|
Protocol, Union, overload, runtime_checkable)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.utils import supports_kw
|
from vllm.utils import supports_kw
|
||||||
|
|
||||||
from .interfaces_base import is_pooling_model
|
from .interfaces_base import VllmModel, is_pooling_model
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> VllmModel:
|
||||||
"""
|
"""
|
||||||
Returns the underlying language model used for text generation.
|
Returns the underlying language model used for text generation.
|
||||||
|
|
||||||
@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
multimodal_embeddings: MultiModalEmbeddings,
|
||||||
|
*,
|
||||||
|
is_multimodal: torch.Tensor,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
|
) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
def _get_text_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
get_input_embeddings: Callable[[Tensor], Tensor],
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[Tensor],
|
||||||
|
handle_oov_mm_token: bool,
|
||||||
|
) -> Tensor:
|
||||||
|
if handle_oov_mm_token and is_multimodal is not None:
|
||||||
|
is_text = ~is_multimodal
|
||||||
|
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||||
|
|
||||||
|
return torch.empty(
|
||||||
|
(input_ids.shape[0], text_embeds.shape[1]),
|
||||||
|
dtype=text_embeds.dtype,
|
||||||
|
device=text_embeds.device,
|
||||||
|
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||||
|
|
||||||
|
return get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[Tensor] = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns the input embeddings merged from the text embeddings from
|
Apply token embeddings to `input_ids`.
|
||||||
input_ids and the multimodal embeddings generated from multimodal
|
|
||||||
kwargs.
|
If `multimodal_embeddings` is passed, scatter them into
|
||||||
|
`input_ids` according to the mask `is_multimodal`.
|
||||||
|
|
||||||
|
In case the multi-modal token IDs exceed the vocabulary size of
|
||||||
|
the language model, you can set `handle_oov_mm_token=False`
|
||||||
|
to avoid calling the language model's `get_input_embeddings` method
|
||||||
|
on those tokens. Note however that doing so increases memory usage
|
||||||
|
as an additional buffer is needed to hold the input embeddings.
|
||||||
"""
|
"""
|
||||||
...
|
from .utils import _merge_multimodal_embeddings
|
||||||
|
|
||||||
|
inputs_embeds = self._get_text_embeddings(
|
||||||
|
input_ids,
|
||||||
|
self.get_language_model().get_input_embeddings,
|
||||||
|
is_multimodal=is_multimodal,
|
||||||
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
if is_multimodal is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||||
|
"please update your model runner according to "
|
||||||
|
"https://github.com/vllm-project/vllm/pull/16229.")
|
||||||
|
|
||||||
|
return _merge_multimodal_embeddings(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
|
is_multimodal=is_multimodal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|||||||
@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
|
|||||||
) -> None:
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply token embeddings to `input_ids`."""
|
||||||
|
...
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
|
|||||||
return supports_kw(model_init, "vllm_config")
|
return supports_kw(model_init, "vllm_config")
|
||||||
|
|
||||||
|
|
||||||
|
def _check_vllm_model_get_input_embeddings(
|
||||||
|
model: Union[type[object], object]) -> bool:
|
||||||
|
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||||
|
if not callable(model_get_input_embeddings):
|
||||||
|
logger.warning(
|
||||||
|
"The model (%s) is missing the `get_input_embeddings` method.",
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
|
||||||
model_forward = getattr(model, "forward", None)
|
model_forward = getattr(model, "forward", None)
|
||||||
if not callable(model_forward):
|
if not callable(model_forward):
|
||||||
@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
|
|||||||
def is_vllm_model(
|
def is_vllm_model(
|
||||||
model: Union[type[object], object],
|
model: Union[type[object], object],
|
||||||
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
|
||||||
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
|
return (_check_vllm_model_init(model)
|
||||||
|
and _check_vllm_model_get_input_embeddings(model)
|
||||||
|
and _check_vllm_model_forward(model))
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|||||||
@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, isin_list, maybe_prefix)
|
||||||
merge_multimodal_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class InternS1MultiModalProjector(nn.Module):
|
class InternS1MultiModalProjector(nn.Module):
|
||||||
@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[torch.Tensor] = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
if multimodal_embeddings is not None and len(
|
||||||
if multimodal_embeddings is not None \
|
multimodal_embeddings) > 0:
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
context_token_ids = [
|
|
||||||
token_id for token_id in (self.img_context_token_id,
|
|
||||||
self.video_context_token_id)
|
|
||||||
if token_id is not None
|
|
||||||
]
|
|
||||||
assert len(context_token_ids) >= 1
|
|
||||||
self._set_visual_token_mask(input_ids)
|
self._set_visual_token_mask(input_ids)
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
# This is to satisfy the type checker for each overload
|
||||||
inputs_embeds,
|
if multimodal_embeddings is None or is_multimodal is None:
|
||||||
multimodal_embeddings,
|
return super().get_input_embeddings(input_ids)
|
||||||
context_token_ids,
|
|
||||||
)
|
return super().get_input_embeddings(
|
||||||
return inputs_embeds
|
input_ids,
|
||||||
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
|
is_multimodal=is_multimodal,
|
||||||
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
|
context_token_ids = [
|
||||||
|
token_id for token_id in (self.img_context_token_id,
|
||||||
|
self.video_context_token_id)
|
||||||
|
if token_id is not None
|
||||||
|
]
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
forward_kwargs = {
|
forward_kwargs = {
|
||||||
|
|||||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
isin_list, maybe_prefix)
|
||||||
|
|
||||||
IMG_START = '<img>'
|
IMG_START = '<img>'
|
||||||
IMG_END = '</img>'
|
IMG_END = '</img>'
|
||||||
@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: Optional[torch.Tensor] = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
if multimodal_embeddings is not None and len(
|
||||||
if multimodal_embeddings is not None \
|
multimodal_embeddings) > 0:
|
||||||
and len(multimodal_embeddings) != 0:
|
|
||||||
context_token_ids = [
|
|
||||||
token_id for token_id in (self.img_context_token_id,
|
|
||||||
self.video_context_token_id)
|
|
||||||
if token_id is not None
|
|
||||||
]
|
|
||||||
assert len(context_token_ids) >= 1
|
|
||||||
self._set_visual_token_mask(input_ids)
|
self._set_visual_token_mask(input_ids)
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
# This is to satisfy the type checker for each overload
|
||||||
inputs_embeds,
|
if multimodal_embeddings is None or is_multimodal is None:
|
||||||
multimodal_embeddings,
|
return super().get_input_embeddings(input_ids)
|
||||||
context_token_ids,
|
|
||||||
)
|
return super().get_input_embeddings(
|
||||||
return inputs_embeds
|
input_ids,
|
||||||
|
multimodal_embeddings=multimodal_embeddings,
|
||||||
|
is_multimodal=is_multimodal,
|
||||||
|
handle_oov_mm_token=handle_oov_mm_token,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
|
context_token_ids = [
|
||||||
|
token_id for token_id in (self.img_context_token_id,
|
||||||
|
self.video_context_token_id)
|
||||||
|
if token_id is not None
|
||||||
|
]
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(
|
||||||
vision_embeddings)
|
input_ids,
|
||||||
|
vision_embeddings,
|
||||||
|
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||||
|
)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
forward_kwargs = {
|
forward_kwargs = {
|
||||||
|
|||||||
@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module):
|
|||||||
multimodal_embeddings += video_embeddings
|
multimodal_embeddings += video_embeddings
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
[
|
|
||||||
self.config.image_token_id,
|
|
||||||
self.config.video_token_id,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def get_input_embeddings_v0(
|
def get_input_embeddings_v0(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user