mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 06:34:58 +08:00)

[spec decode] Consolidate speculative decode method name for MTP (#25232)

Signed-off-by: zixi-qi <qizixi@meta.com>

parent c3dfb0f6dd
commit c214d699fd
@@ -54,6 +54,7 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
+        choices=["ngram", "eagle", "eagle3", "mtp"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
-    elif args.method.endswith("mtp"):
+    elif args.method == "mtp":
         speculative_config = {
-            "method": args.method,
+            "method": "mtp",
             "num_speculative_tokens": args.num_spec_tokens,
         }
     else:
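For context, a minimal offline-inference sketch of the consolidated method name (not part of the diff; the model is taken from the test parametrization added below, and the prompt and sampling settings are illustrative):

    from vllm import LLM, SamplingParams

    # "mtp" now selects MTP speculative decoding for any MTP-capable checkpoint;
    # the per-model names (deepseek_mtp, mimo_mtp, ...) become deprecated aliases.
    llm = LLM(
        model="XiaomiMiMo/MiMo-7B-Base",
        trust_remote_code=True,
        speculative_config={
            "method": "mtp",
            "num_speculative_tokens": 1,
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0, max_tokens=32))
    print(outputs[0].outputs[0].text)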
@@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.platforms import current_platform
 
+MTP_SIMILARITY_RATE = 0.8
+
 
 def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
@@ -222,3 +224,66 @@ def test_eagle_correctness(
         del spec_llm
         torch.cuda.empty_cache()
         cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+],
+                         ids=["mimo", "deepseek"])
+def test_mtp_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, int],
+    mm_enabled: bool,
+):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using MTP speculative decoding.
+    model_setup: (method, model_name, tp_size)
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_MLA_DISABLE", "1")
+
+        method, model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size,
+                      trust_remote_code=True)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "num_speculative_tokens": 1,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 80% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
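The acceptance check above reduces to an exact-match rate compared against MTP_SIMILARITY_RATE. A standalone sketch of that heuristic (illustrative only, independent of vLLM):

    MTP_SIMILARITY_RATE = 0.8

    def passes_similarity(ref_texts: list[str], spec_texts: list[str]) -> bool:
        # Greedy outputs should be unchanged by speculative decoding, so most
        # prompts are expected to match the reference exactly.
        matches = sum(r == s for r, s in zip(ref_texts, spec_texts))
        return matches > int(MTP_SIMILARITY_RATE * len(ref_texts))

    assert passes_similarity(["a"] * 10, ["a"] * 9 + ["b"])   # 9/10 matches
    assert not passes_similarity(["a"] * 10, ["b"] * 10)      # 0/10 matches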
tests/v1/spec_decode/test_mtp.py (new file, 195 lines)
@@ -0,0 +1,195 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest import mock
+
+import pytest
+import torch
+
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.config.load import LoadConfig
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
+from vllm.v1.spec_decode.eagle import EagleProposer
+
+mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
+
+
+def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
+    """Create an MTP proposer with unified model configuration."""
+    model_config = ModelConfig(model=mimo_7b_dir,
+                               runner="generate",
+                               max_model_len=100,
+                               trust_remote_code=True)
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        model=mimo_7b_dir,
+        method="mtp",
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
+
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
+
+
+@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
+@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
+                                mock_get_pp_group):
+    """Test MTP-specific model loading with unified model approach."""
+
+    # Setup mocks
+    mock_model = mock.MagicMock()
+    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+    mock_get_model.return_value = mock_model
+
+    target_attn_layers = {"target_attn_1": mock.MagicMock()}
+    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
+    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = 1
+    mock_get_pp_group.return_value = mock_pp_group
+
+    # Create target model
+    class _TargetModelStub(LlamaForCausalLM):
+        model: mock.MagicMock
+        lm_head: mock.MagicMock
+
+    target_model = mock.create_autospec(_TargetModelStub, instance=True)
+    target_model.model = mock.MagicMock()
+    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+
+    # Create MTP proposer
+    proposer = _create_mtp_proposer(num_speculative_tokens=4)
+    proposer.load_model(target_model)
+
+    # Verify MTP-specific behavior:
+    # Model is loaded
+    mock_get_model.assert_called_once()
+    # MTP shares lm_head with target model
+    assert proposer.model.lm_head == target_model.lm_head
+    # MTP shares embed_tokens with target model
+    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1])
+def test_mtp_propose(num_speculative_tokens, monkeypatch):
+    """Test that MTP's forward method returns hidden states directly"""
+
+    device = torch.device(current_platform.device_type)
+    batch_size = 2
+    seq_lens = [5, 3]
+    total_tokens = sum(seq_lens)
+    vocab_size = 100
+
+    proposer = _create_mtp_proposer(num_speculative_tokens)
+    hidden_size = proposer.hidden_size
+
+    # Mock the MTP model to verify it returns hidden states directly
+    model_mock = mock.MagicMock()
+
+    # MTP returns hidden states directly
+    if num_speculative_tokens == 1:
+        model_mock.return_value = torch.zeros(total_tokens,
+                                              hidden_size,
+                                              device=device)
+    else:
+        # Multiple forward passes for multi-token speculation
+        forward_returns = []
+        for i in range(num_speculative_tokens):
+            if i == 0:
+                h_states = torch.zeros(total_tokens,
+                                       hidden_size,
+                                       device=device)
+            else:
+                h_states = torch.zeros(batch_size, hidden_size, device=device)
+            forward_returns.append(h_states)
+        model_mock.side_effect = forward_returns
+
+    # Mock compute_logits
+    def create_deterministic_logits(batch_size, vocab_size, token_offset):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        logits[:, token_offset] = 100.0
+        return logits
+
+    if num_speculative_tokens == 1:
+        model_mock.compute_logits.return_value = create_deterministic_logits(
+            batch_size, vocab_size, 42)
+    else:
+        logits_returns = [
+            create_deterministic_logits(batch_size, vocab_size, 42 + i)
+            for i in range(num_speculative_tokens)
+        ]
+        model_mock.compute_logits.side_effect = logits_returns
+
+    proposer.model = model_mock
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Prepare inputs
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
+    common_attn_metadata = create_common_attn_metadata(batch_spec,
+                                                       block_size=16,
+                                                       device=device)
+
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_lens[0], device=device),
+        torch.arange(seq_lens[1], device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    sampling_metadata = mock.MagicMock()
+
+    # Setup attention metadata
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    proposer.runner = mock.MagicMock()
+    proposer.attn_metadata_builder = attn_metadata_builder
+
+    # Run propose
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              last_token_indices=None,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+
+    # Verify the model was called correctly
+    assert model_mock.called
+    # Verify output shape
+    assert result.shape == (batch_size, num_speculative_tokens)
@@ -32,7 +32,9 @@ logger = init_logger(__name__)
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
                             "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                            "longcat_flash_mtp"]
+                            "longcat_flash_mtp", "mtp"]
+MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
+                   "qwen3_next_mtp", "longcat_flash_mtp")
 
 
 @config
@@ -207,11 +209,16 @@ class SpeculativeConfig:
         # can not be detected, it will be considered as the "draft_model" by
         # default.
 
+        if self.method in MTP_MODEL_TYPES:
+            logger.warning("method `%s` is deprecated and replaced with mtp.",
+                           self.method)
+            self.method = "mtp"
+
         if self.model is None and self.num_speculative_tokens is not None:
-            # TODO(Shangming): Refactor mtp configuration logic when supporting
-            if (self.target_model_config
-                    and self.target_model_config.hf_text_config.model_type
-                    in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
+            if self.method == "mtp":
+                assert (
+                    self.target_model_config
+                    is not None), "target_model_config must be present for mtp"
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
             # Align the quantization of draft model for cases such as
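The effect of the new block is a plain normalization of the legacy method names; a standalone sketch that mirrors the diff (no vLLM import, warning elided):

    MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                       "qwen3_next_mtp", "longcat_flash_mtp")

    def normalize_method(method: str) -> str:
        # Legacy per-model MTP names are still accepted but folded into "mtp",
        # matching the deprecation warning emitted by SpeculativeConfig.
        if method in MTP_MODEL_TYPES:
            return "mtp"
        return method

    assert normalize_method("deepseek_mtp") == "mtp"
    assert normalize_method("eagle3") == "eagle3"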
@@ -314,31 +321,13 @@ class SpeculativeConfig:
                       "mlp_speculator"):
                 self.method = "mlp_speculator"
             elif (self.draft_model_config.hf_config.model_type
-                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
-                self.method = "deepseek_mtp"
+                  in MTP_MODEL_TYPES):
+                self.method = "mtp"
                 if self.num_speculative_tokens > 1:
                     logger.warning(
-                            "All Deepseek MTP models only have " \
-                            "one layer. Might need some code changes " \
-                            "to support multiple layers."
-                        )
-            elif (self.draft_model_config.hf_config.model_type ==
-                    "ernie_mtp"):
-                self.method = "ernie_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                            "All Ernie MTP models only have " \
-                            "one layer. Might need some code changes " \
-                            "to support multiple layers."
-                        )
-            elif (self.draft_model_config.hf_config.model_type ==
-                    "qwen3_next_mtp"):
-                self.method = "qwen3_next_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                            "All Qwen3Next MTP models only have " \
-                            "one layer. Might need some code changes " \
-                            "to support multiple layers."
-                        )
+                        "Enabling num_speculative_tokens > 1 will run" \
+                        "multiple times of forward on same MTP layer" \
+                        ",which may result in lower acceptance rate" \
                     )
             elif (self.draft_model_config.hf_config.model_type
                   in ("longcat_flash_mtp")):
@@ -355,7 +344,7 @@ class SpeculativeConfig:
                     "Speculative decoding with draft model is not "
                     "supported yet. Please consider using other "
                     "speculative decoding methods such as ngram, medusa, "
-                    "eagle, or deepseek_mtp.")
+                    "eagle, or mtp.")
 
         # Replace hf_config for EAGLE draft_model
         if self.method in ("eagle", "eagle3"):
@@ -564,8 +553,7 @@ class SpeculativeConfig:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp", "longcat_flash_mtp")
+        return self.method in ("eagle", "eagle3", "mtp")
 
     def __repr__(self) -> str:
         method = self.method
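Downstream checks collapse to a single membership test; a sketch of the simplified predicate (same logic as the diff, written as a free function for illustration):

    def use_eagle(method: str) -> bool:
        # After consolidation, MTP rides the EAGLE code path alongside
        # eagle/eagle3 instead of being enumerated per model family.
        return method in ("eagle", "eagle3", "mtp")

    assert use_eagle("mtp") and not use_eagle("ngram")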
@@ -1486,7 +1486,7 @@ class EngineArgs:
             raise NotImplementedError(
                 "Draft model speculative decoding is not supported yet. "
                 "Please consider using other speculative decoding methods "
-                "such as ngram, medusa, eagle, or deepseek_mtp.")
+                "such as ngram, medusa, eagle, or mtp.")
 
         V1_BACKENDS = [
             "FLASH_ATTN",
@@ -235,8 +235,7 @@ class EagleProposer:
                 hidden_states=self.hidden_states[:num_input_tokens],
                 inputs_embeds=inputs_embeds,
             )
-            if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
-                               "longcat_flash_mtp"):
+            if self.method == "mtp":
                 last_hidden_states = ret_hidden_states
                 hidden_states = last_hidden_states
             else:
@@ -365,8 +364,7 @@ class EagleProposer:
                 hidden_states=self.hidden_states[:input_batch_size],
                 inputs_embeds=inputs_embeds,
             )
-            if self.method in ("deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp", "longcat_flash_mtp"):
+            if self.method == "mtp":
                 last_hidden_states = ret_hidden_states
                 hidden_states = ret_hidden_states
             else:
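Both EagleProposer hunks reduce the same branch; a sketch of the resulting control flow (the non-MTP branch here is an assumption for illustration, since the diff does not show its body):

    def select_hidden_states(method: str, ret_hidden_states):
        if method == "mtp":
            # MTP draft layers return the hidden-states tensor directly.
            last_hidden_states = ret_hidden_states
            hidden_states = ret_hidden_states
        else:
            # Assumed: EAGLE-style drafts return a (last, all) pair.
            last_hidden_states, hidden_states = ret_hidden_states
        return last_hidden_states, hidden_states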
@@ -922,10 +920,10 @@ class EagleProposer:
     def _get_attention_metadata_builder(
             self) -> list[AttentionMetadataBuilder]:
         """Find and return the attention metadata builders for EAGLE layers.
 
         Returns:
             The metadata builders for EAGLE layers.
 
         Raises:
             AssertionError: If no metadata builders are found for EAGLE layers.
         """