diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index eed1b3fb4bc85..c1eb207efcd18 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -75,7 +75,12 @@ This section details the necessary modifications to make to a Transformers compa
 To make your model compatible with the Transformers backend, it needs:
 
 1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
-    1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
+    - If your model is encoder-only:
+        1. Add `is_causal = False` to `MyAttention`.
+    - If your model is mixture-of-experts (MoE):
+        1. Your sparse MoE block must have an attribute called `experts`.
+        2. The class of `experts` (`MyExperts`) must inherit from `nn.ModuleList`.
+        3. `MyExperts.forward` must accept `hidden_states`, `top_k_index`, `top_k_weights`.
 2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
 3. `MyModel` must contain `_supports_attention_backend = True`.
 
@@ -102,6 +107,23 @@ class MyAttention(nn.Module):
         )
         ...
 
+# Only do this for mixture-of-experts models
+class MyExperts(nn.ModuleList):
+    def forward(self, hidden_states, top_k_index, top_k_weights):
+        ...
+
+# Only do this for mixture-of-experts models
+class MySparseMoEBlock(nn.Module):
+    def __init__(self, config):
+        ...
+        self.experts = MyExperts(config)
+        ...
+
+    def forward(self, hidden_states: torch.Tensor):
+        ...
+        hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
+        ...
+
 class MyModel(PreTrainedModel):
     _supports_attention_backend = True
 ```
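
As a reading aid for this diff, below is a minimal, self-contained sketch of the MoE interface the new requirements describe: an attribute named `experts` whose class subclasses `nn.ModuleList` and whose `forward` accepts `hidden_states`, `top_k_index`, and `top_k_weights`. This is not vLLM or Transformers code; the router, the per-expert MLP, and the default sizes/`top_k` values are illustrative assumptions, and only the attribute name, the base class, and the three-argument `forward` signature come from the documentation above.

```python
import torch
from torch import nn


class MyExperts(nn.ModuleList):
    """Experts container: subclasses nn.ModuleList and exposes the required 3-argument forward."""

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        # Illustrative per-expert MLP; real models define their own expert modules.
        super().__init__(
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.GELU(),
                nn.Linear(intermediate_size, hidden_size),
            )
            for _ in range(num_experts)
        )

    def forward(self, hidden_states, top_k_index, top_k_weights):
        # hidden_states: (num_tokens, hidden_size)
        # top_k_index / top_k_weights: (num_tokens, top_k)
        output = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(self):
            token_idx, k_idx = torch.where(top_k_index == expert_id)
            if token_idx.numel() == 0:
                continue
            expert_out = expert(hidden_states[token_idx])
            # Scale each expert's output by its routing weight and scatter it back.
            output.index_add_(
                0, token_idx, expert_out * top_k_weights[token_idx, k_idx].unsqueeze(-1)
            )
        return output


class MySparseMoEBlock(nn.Module):
    def __init__(self, hidden_size=32, intermediate_size=64, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, num_experts)
        # The attribute must be named `experts`, per the requirements above.
        self.experts = MyExperts(num_experts, hidden_size, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, hidden_size = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_size)
        routing_probs = self.router(flat).softmax(dim=-1)
        top_k_weights, top_k_index = routing_probs.topk(self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        out = self.experts(flat, top_k_index, top_k_weights)
        return out.reshape(batch_size, seq_len, hidden_size)


if __name__ == "__main__":
    block = MySparseMoEBlock()
    x = torch.randn(2, 5, 32)
    print(block(x).shape)  # torch.Size([2, 5, 32])
```

Everything besides the `experts` attribute name, the `nn.ModuleList` inheritance, and the `(hidden_states, top_k_index, top_k_weights)` signature is free to vary per model.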