mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 09:07:03 +08:00
Merge branch 'main' into mlm-full-lora-support
This commit is contained in:
commit
94dce5c3d9
@ -1223,6 +1223,8 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# Alot of these tests are on the edge of OOMing
|
||||
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error(
|
||||
# Hugging Face model identifier with download error
|
||||
path = "org/repo"
|
||||
mock_exist.return_value = False
|
||||
mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")
|
||||
mock_snapshot_download.side_effect = HfHubHTTPError(
|
||||
"failed to query model info",
|
||||
response=MagicMock(),
|
||||
)
|
||||
assert get_adapter_absolute_path(path) == path
|
||||
|
||||
@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
|
||||
"mm_encoder_attn_backend",
|
||||
[None] + current_platform.get_supported_vit_attn_backends(),
|
||||
)
|
||||
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
|
||||
@create_new_process_for_each_test()
|
||||
def test_vit_backend_functionality(
|
||||
model_key: str,
|
||||
|
||||
@ -642,48 +642,130 @@ _OPS_REGISTERED = False
|
||||
|
||||
|
||||
class rocm_aiter_ops:
|
||||
"""ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
|
||||
|
||||
This class centralizes the import and registration of AITER ops,
|
||||
and provides a unified interface for checking if AITER is enabled.
|
||||
Operations are only available on supported gfx9
|
||||
architectures when aiter is installed.
|
||||
|
||||
The class uses environment variables to control which features are enabled,
|
||||
allowing fine-grained control over which AITER optimizations are used.
|
||||
|
||||
Environment Variables:
|
||||
VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
|
||||
VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
|
||||
VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
|
||||
VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
|
||||
VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
|
||||
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
|
||||
VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
|
||||
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
|
||||
VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
|
||||
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
|
||||
VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
|
||||
|
||||
Note:
|
||||
The environment variables are assigned when the module is imported,
|
||||
so you can't change the environment variables after the module is imported.
|
||||
This is done out of performance consideration. Accessing environment variables
|
||||
is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
|
||||
so we don't want to do it repeatedly, especially in the hot path (the forward pass).
|
||||
You can call the refresh_env_variables() function to reload the env variables
|
||||
after monkey patching the env variables in the unit test.
|
||||
|
||||
Check Functions:
|
||||
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
|
||||
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
|
||||
(3) aiter library is installed. The check function then also verifies
|
||||
the corresponding environment variable is enabled.
|
||||
i.e. ___
|
||||
is_enabled() == current_platform.is_rocm() and | checked by
|
||||
current_platform.is_on_gfx9() and | @if_aiter_supported
|
||||
IS_AITER_FOUND and _______________|
|
||||
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`
|
||||
|
||||
Example:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
# Check if aiter is enabled before using operations
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
|
||||
|
||||
Operations:
|
||||
- RMS normalization: rms_norm, rms_norm2d_with_add
|
||||
- GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
|
||||
- Fused MoE: fused_moe, asm_moe_tkw1
|
||||
- Routing: topk_softmax, biased_grouped_topk, grouped_topk
|
||||
- MLA decode: mla_decode_fwd
|
||||
- Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
|
||||
- Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
|
||||
"""
|
||||
|
||||
# Check if the env variable is set
|
||||
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
|
||||
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
|
||||
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||
|
||||
@classmethod
|
||||
def refresh_env_variables(cls):
|
||||
"""
|
||||
Since the environment variables are assigned when the module is imported,
|
||||
This is a helper function to reload all the env variables from
|
||||
the environment variables.
|
||||
for example, after monkey patching the env variables in the unit test,
|
||||
you can call this function to reload the env variables.
|
||||
"""
|
||||
cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
|
||||
cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_enabled(cls) -> bool:
|
||||
"""Verifies device specs and availability of aiter main env variable."""
|
||||
return cls._AITER_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_linear_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._LINEAR_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_linear_fp8_enaled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls.is_linear_enabled()
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_rmsnorm_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_fused_moe_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._FMOE_ENABLED
|
||||
|
||||
@classmethod
|
||||
@ -694,25 +776,16 @@ class rocm_aiter_ops:
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_mla_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._MLA_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_mha_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._MHA_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_pa_attn_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_triton_unified_attn_enabled(cls) -> bool:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -937,7 +937,7 @@ class CompilationConfig:
|
||||
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
):
|
||||
logger.warning_once(
|
||||
"Using piecewise compilation with empty splitting_ops"
|
||||
"Using piecewise cudagraph with empty splitting_ops"
|
||||
)
|
||||
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
|
||||
logger.warning_once(
|
||||
|
||||
@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE):
|
||||
self._shared_experts = shared_experts
|
||||
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb, because of correctness issues
|
||||
# - we are using flashinfer with DP, since there nothing to gain
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
# - we are using flashinfer with DP, since there nothint to gain
|
||||
# - we are using marlin kernels
|
||||
backend = self.moe_parallel_config.all2all_backend
|
||||
self.use_overlapped = (
|
||||
use_overlapped
|
||||
and not (
|
||||
# TODO(wentao): find the root cause and remove this condition
|
||||
self.enable_eplb
|
||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
||||
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
|
||||
@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
|
||||
|
||||
layer.gemm1_weights_fp4_shuffled = Parameter(
|
||||
layer.w13_weight = Parameter(
|
||||
gemm1_weights_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm2_weights_fp4_shuffled = Parameter(
|
||||
gemm2_weights_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm1_scales_fp4_shuffled = Parameter(
|
||||
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
gemm1_scales_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm2_scales_fp4_shuffled = Parameter(
|
||||
layer.w2_weight_scale = Parameter(
|
||||
gemm2_scales_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
|
||||
@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up weights that won't be used by TRT-LLM
|
||||
del layer.w2_weight
|
||||
del layer.w2_weight_scale
|
||||
del layer.w13_weight
|
||||
del layer.w13_weight_scale
|
||||
else:
|
||||
# swizzle weight scales
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
|
||||
@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
)
|
||||
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
|
||||
|
||||
layer.gemm1_weights_fp4_shuffled = Parameter(
|
||||
layer.w13_weight = Parameter(
|
||||
gemm1_weights_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm2_weights_fp4_shuffled = Parameter(
|
||||
gemm2_weights_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm1_scales_fp4_shuffled = Parameter(
|
||||
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
gemm1_scales_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
layer.gemm2_scales_fp4_shuffled = Parameter(
|
||||
layer.w2_weight_scale = Parameter(
|
||||
gemm2_scales_fp4_shuffled, requires_grad=False
|
||||
)
|
||||
|
||||
@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up weights that won't be used by TRT-LLM
|
||||
del layer.w2_weight
|
||||
del layer.w2_weight_scale
|
||||
del layer.w13_weight
|
||||
del layer.w13_weight_scale
|
||||
elif self.use_marlin:
|
||||
# Marlin processing
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
|
||||
@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).flatten(),
|
||||
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
|
||||
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn
|
||||
),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
|
||||
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn
|
||||
),
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).flatten(),
|
||||
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
|
||||
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn
|
||||
),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
|
||||
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn
|
||||
),
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
|
||||
@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).unsqueeze(0),
|
||||
)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError(
|
||||
"Only 'absolute' position_embedding_type" + " is supported"
|
||||
|
||||
@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
|
||||
torch.arange(config.max_position_embeddings).unsqueeze(0),
|
||||
)
|
||||
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError(
|
||||
"Only 'absolute' position_embedding_type" + " is supported"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
def _build_model(
|
||||
self, vllm_config: VllmConfig, prefix: str = ""
|
||||
) -> BertModel | BertWithRope:
|
||||
if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
|
||||
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
kwargs = dict(vllm_config=vllm_config, prefix=prefix)
|
||||
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
|
||||
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
|
||||
else:
|
||||
return BertModel(
|
||||
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
|
||||
)
|
||||
return JinaRobertaModel(**kwargs)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
|
||||
@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
head_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
head_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states, attention_mask, head_mask, output_attentions
|
||||
)
|
||||
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
return outputs
|
||||
@ -339,18 +335,14 @@ class SwinStage(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
head_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
always_partition: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
height, width = input_dimensions
|
||||
for i, layer_module in enumerate(self.blocks):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
always_partition,
|
||||
)
|
||||
@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_dimensions: tuple[int, int],
|
||||
head_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
always_partition: bool | None = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
input_dimensions,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
always_partition,
|
||||
)
|
||||
@ -473,7 +461,6 @@ class SwinModel(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
head_mask: torch.FloatTensor | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
) -> tuple[torch.Tensor]:
|
||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||
@ -481,7 +468,6 @@ class SwinModel(nn.Module):
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
input_dimensions,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
"""PyTorch Ultravox model."""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from types import SimpleNamespace
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
|
||||
)
|
||||
hidden_states = hidden_states + positions
|
||||
|
||||
# Backward compatibility for Transformers v4 where layer_head_mask
|
||||
# was a required argument for WhisperEncoderLayer.forward
|
||||
kwargs = {}
|
||||
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
|
||||
kwargs["layer_head_mask"] = None
|
||||
|
||||
for layer in self.layers:
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=extended_attention_mask,
|
||||
layer_head_mask=None,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
||||
|
||||
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
|
||||
|
||||
# Backward compatibility for Transformers v4 where layer_head_mask
|
||||
# was a required argument for WhisperEncoderLayer.forward
|
||||
kwargs = {}
|
||||
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
|
||||
kwargs["layer_head_mask"] = None
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention(
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> bool:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (rocm_aiter_ops.is_pa_attn_enabled())
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
qo_indptr: torch.Tensor | None = None
|
||||
# The dtype of MLA out tensor
|
||||
attn_out_dtype: torch.dtype = torch.bfloat16
|
||||
# The max query output length: int
|
||||
max_qo_len: int | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.qo_indptr = torch.arange(
|
||||
0, max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
self.qo_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
def _build_decode(
|
||||
@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_qo_len = qo_len.max().item()
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
|
||||
self.qo_indptr[: 1 + num_reqs].copy_(
|
||||
query_start_loc_device, non_blocking=True
|
||||
)
|
||||
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
|
||||
qo_indptr = self.qo_indptr[: 1 + num_reqs]
|
||||
|
||||
else:
|
||||
@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr,
|
||||
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
|
||||
max_qo_len=max_qo_len,
|
||||
attn_out_dtype=self.decode_attn_out_dtype,
|
||||
)
|
||||
|
||||
@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
rocm_aiter_ops.mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer,
|
||||
o,
|
||||
self.scale,
|
||||
attn_metadata.decode.qo_indptr,
|
||||
max_seqlen_qo,
|
||||
attn_metadata.decode.max_qo_len,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
|
||||
@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface):
|
||||
if self.is_encoder_decoder
|
||||
else EncoderCacheManager(cache_size=encoder_cache_size)
|
||||
)
|
||||
# For encoder-decoder models, allocate the maximum number of tokens for Cross
|
||||
# Attn blocks, as for Whisper its input is always padded to the maximum length.
|
||||
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
|
||||
self._num_encoder_max_input_tokens = (
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
|
||||
)
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
self.use_eagle = False
|
||||
@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface):
|
||||
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
|
||||
)
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
# TODO(russellb): For Whisper, we know that the input is
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
num_encoder_tokens = (
|
||||
self.scheduler_config.max_num_encoder_input_tokens
|
||||
)
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
num_encoder_tokens = (
|
||||
self._num_encoder_max_input_tokens
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs
|
||||
else 0
|
||||
)
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
|
||||
@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import transformers.file_utils as file_utils
|
||||
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
|
||||
import xgrammar as xgr
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
@ -30,10 +30,8 @@ else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
|
||||
tokenization_gpt2 = LazyLoader(
|
||||
"tokenization_gpt2",
|
||||
globals(),
|
||||
"transformers.models.gpt2.tokenization_gpt2",
|
||||
bytes_to_unicode = LazyLoader(
|
||||
"bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer"
|
||||
)
|
||||
|
||||
TokenizerLike = object
|
||||
@ -204,7 +202,7 @@ def _reduced_vocabulary(
|
||||
A Dict of token string -> equivalent token ids
|
||||
"""
|
||||
|
||||
unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}
|
||||
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
@ -145,12 +145,20 @@ class WorkspaceManager:
|
||||
|
||||
for ubatch_id in range(self._num_ubatches):
|
||||
current_workspace = self._current_workspaces[ubatch_id]
|
||||
if current_workspace is None:
|
||||
if (
|
||||
current_workspace is None
|
||||
or self._workspace_size_bytes(current_workspace) < required_bytes
|
||||
):
|
||||
# Delete old tensor before allocating new one to avoid
|
||||
# memory spike from resize_(). resize_() allocates new
|
||||
# memory before freeing old, which can cause OOM.
|
||||
# Must clear the list reference first since local var
|
||||
# is just a copy of the reference.
|
||||
self._current_workspaces[ubatch_id] = None
|
||||
del current_workspace
|
||||
self._current_workspaces[ubatch_id] = torch.empty(
|
||||
(required_bytes,), dtype=torch.uint8, device=self._device
|
||||
)
|
||||
elif self._workspace_size_bytes(current_workspace) < required_bytes:
|
||||
current_workspace.resize_(required_bytes)
|
||||
|
||||
if envs.VLLM_DEBUG_WORKSPACE:
|
||||
logger.info(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user