Mirror of https://git.datalinker.icu/ali-vilab/TeaCache
Synced 2026-04-22 09:27:09 +08:00
69 lines · 2.6 KiB · Python
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
|
|
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
|
|
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
|
|
|
|
|
|
class T5EncoderPolicy(Policy):
    """Shardformer policy for a T5 encoder.

    Registers optional submodule/method replacements on the HF T5 encoder:
    apex-backed layer norms (when apex is installed) and JIT-fused
    forward/dropout-add methods (when enabled in the shard config).
    Tensor parallelism and flash attention are not supported by this policy.
    """

    def config_sanity_check(self):
        # This policy supports neither tensor parallelism nor flash attention.
        assert not self.shard_config.enable_tensor_parallelism
        assert not self.shard_config.enable_flash_attention

    def preprocess(self):
        # No model surgery is needed before sharding.
        return self.model

    def module_policy(self):
        """Build and return the replacement policy dict for the T5 encoder."""
        from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack

        policy = {}

        # check whether apex is installed
        try:
            from apex.normalization import FusedRMSNorm  # noqa
            from videosys.core.shardformer.t5.modeling import T5LayerNorm

            # recover hf from fused rms norm to T5 norm which is faster
            norm_sites = (
                (T5LayerFF, "layer_norm"),
                (T5LayerSelfAttention, "layer_norm"),
                (T5Stack, "final_layer_norm"),
            )
            for target, suffix in norm_sites:
                self.append_or_create_submodule_replacement(
                    description=SubModuleReplacementDescription(suffix=suffix, target_module=T5LayerNorm),
                    policy=policy,
                    target_key=target,
                )
        except (ImportError, ModuleNotFoundError):
            # apex is unavailable — keep the stock HF layer norms.
            pass

        # use jit operator
        if self.shard_config.enable_jit_fused:
            fused_forwards = (
                (T5LayerFF, get_jit_fused_T5_layer_ff_forward()),
                (T5LayerSelfAttention, get_T5_layer_self_attention_forward()),
            )
            for target, fused_forward in fused_forwards:
                self.append_or_create_method_replacement(
                    description={
                        "forward": fused_forward,
                        "dropout_add": get_jit_fused_dropout_add_func(),
                    },
                    policy=policy,
                    target_key=target,
                )

        return policy

    def postprocess(self):
        # No post-shard fixup is needed.
        return self.model