Merge branch 'main' into mlm-full-lora-support

This commit is contained in:
Jee Jee Li 2025-12-17 00:33:42 +08:00 committed by GitHub
commit 94dce5c3d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 182 additions and 123 deletions

View File

@ -1223,6 +1223,8 @@ steps:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# A lot of these tests are on the edge of OOMing
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py

View File

@ -3,7 +3,7 @@
from collections import OrderedDict
from typing import NamedTuple
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from huggingface_hub.utils import HfHubHTTPError
@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error(
# Hugging Face model identifier with download error
path = "org/repo"
mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")
mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info",
response=MagicMock(),
)
assert get_adapter_absolute_path(path) == path

View File

@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(),
)
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
@create_new_process_for_each_test()
def test_vit_backend_functionality(
model_key: str,

View File

@ -642,48 +642,130 @@ _OPS_REGISTERED = False
class rocm_aiter_ops:
"""ROCm AITER operations wrapper for AMD GPU acceleration in vLLM.
This class centralizes the import and registration of AITER ops,
and provides a unified interface for checking if AITER is enabled.
Operations are only available on supported gfx9
architectures when aiter is installed.
The class uses environment variables to control which features are enabled,
allowing fine-grained control over which AITER optimizations are used.
Environment Variables:
VLLM_ROCM_USE_AITER: Main toggle for all AITER operations.
VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops.
VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations.
VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops.
VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops.
VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention.
VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply.
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM.
VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings.
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion.
VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM.
Note:
The environment variables are assigned when the module is imported,
so you can't change the environment variables after the module is imported.
This is done out of performance consideration. Accessing environment variables
is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067
so we don't want to do it repeatedly, especially in the hot path (the forward pass).
You can call the refresh_env_variables() function to reload the env variables
after monkey patching the env variables in the unit test.
Check Functions:
All check functions (is_*_enabled) are decorated with @if_aiter_supported,
which verifies: (1) platform is ROCm, (2) device arch is gfx9, and
(3) aiter library is installed. The check function then also verifies
the corresponding environment variable is enabled.
i.e. ___
is_enabled() == current_platform.is_rocm() and | checked by
current_platform.is_on_gfx9() and | @if_aiter_supported
IS_AITER_FOUND and _______________|
cls._AITER_ENABLED -----> Check by the logic in `is_enabled()`
Example:
from vllm._aiter_ops import rocm_aiter_ops
# Check if aiter is enabled before using operations
if rocm_aiter_ops.is_enabled():
result = rocm_aiter_ops.rms_norm(x, weight, epsilon)
Operations:
- RMS normalization: rms_norm, rms_norm2d_with_add
- GEMM operations: gemm_a8w8, gemm_a8w8_blockscale
- Fused MoE: fused_moe, asm_moe_tkw1
- Routing: topk_softmax, biased_grouped_topk, grouped_topk
- MLA decode: mla_decode_fwd
- Quantization: per_tensor_quant, per_token_quant, group_fp8_quant
- Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale
"""
# Check if the env variable is set
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod
def refresh_env_variables(cls):
"""
Since the environment variables are assigned when the module is imported,
this is a helper function to reload all the env variables from
the environment.
For example, after monkey patching the env variables in the unit test,
you can call this function to reload the env variables.
"""
cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED
@classmethod
@if_aiter_supported
def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled()
@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod
@ -694,25 +776,16 @@ class rocm_aiter_ops:
@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod
@if_aiter_supported
def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod

View File

@ -937,7 +937,7 @@ class CompilationConfig:
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
):
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
"Using piecewise cudagraph with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(

View File

@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE):
self._shared_experts = shared_experts
# Disable shared expert overlap if:
# - we are using eplb, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there is nothing to gain
# - we are using marlin kernels
backend = self.moe_parallel_config.all2all_backend
self.use_overlapped = (
use_overlapped
and not (
# TODO(wentao): find the root cause and remove this condition
self.enable_eplb
(self.enable_eplb and backend != "allgather_reducescatter")
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None

View File

@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter(
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(

View File

@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter(
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
elif self.use_marlin:
# Marlin processing
prepare_moe_fp4_layer_for_marlin(layer)

View File

@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,

View File

@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"

View File

@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
)
def forward(
self,
input_ids: torch.Tensor,
@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model(
self, vllm_config: VllmConfig, prefix: str = ""
) -> BertModel | BertWithRope:
if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
hf_config = vllm_config.model_config.hf_config
kwargs = dict(vllm_config=vllm_config, prefix=prefix)
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
else:
return BertModel(
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
)
return JinaRobertaModel(**kwargs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)

View File

@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)

View File

@ -5,6 +5,7 @@
"""PyTorch Ultravox model."""
import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias
@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
)
hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers:
layer_outputs = layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -124,8 +124,6 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)

View File

@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length: int
max_qo_len: int | None = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
def __init__(
self,
@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs, dtype=torch.int32, device=device
)
self.qo_indptr = torch.arange(
0, max_num_reqs + 1, dtype=torch.int32, device=device
self.qo_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
)
def _build_decode(
@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device.cumsum(dim=0, dtype=torch.int32),
]
)
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_qo_len = qo_len.max().item()
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True
)
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
qo_indptr = self.qo_indptr[: 1 + num_reqs]
else:
@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype,
)
@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,
self.scale,
attn_metadata.decode.qo_indptr,
max_seqlen_qo,
attn_metadata.decode.max_qo_len,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,

View File

@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface):
if self.is_encoder_decoder
else EncoderCacheManager(cache_size=encoder_cache_size)
)
# For encoder-decoder models, allocate the maximum number of tokens for Cross
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self._num_encoder_max_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
)
speculative_config = vllm_config.speculative_config
self.use_eagle = False
@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface):
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens = (
self.scheduler_config.max_num_encoder_input_tokens
)
else:
num_encoder_tokens = 0
num_encoder_tokens = (
self._num_encoder_max_input_tokens
if self.is_encoder_decoder and request.has_encoder_inputs
else 0
)
new_blocks = self.kv_cache_manager.allocate_slots(
request,

View File

@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING:
import outlines_core as oc
import transformers.file_utils as file_utils
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr
from transformers.convert_slow_tokenizer import bytes_to_unicode
from vllm.tokenizers import TokenizerLike
from vllm.v1.worker.gpu_input_batch import InputBatch
@ -30,10 +30,8 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
oc = LazyLoader("oc", globals(), "outlines_core")
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
tokenization_gpt2 = LazyLoader(
"tokenization_gpt2",
globals(),
"transformers.models.gpt2.tokenization_gpt2",
bytes_to_unicode = LazyLoader(
"bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer"
)
TokenizerLike = object
@ -204,7 +202,7 @@ def _reduced_vocabulary(
A Dict of token string -> equivalent token ids
"""
unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token])

View File

@ -145,12 +145,20 @@ class WorkspaceManager:
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None:
if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(