TemplatePairStack 是用于更新蛋白质结构模板（template）pair 表示 `pair_act` 的类：
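它通过 `layer_stack.layer_stack(c.num_block)(block)` 将 `block` 函数堆叠 `c.num_block` 次（配置文件中为 2）；若 `c.num_block` 为 0，则直接返回输入的 `pair_act`。每个 block 以 `pair_act` 和 `pair_mask` 为输入，依次执行以下 5 个子层，且每个子层都经由 `dropout_wrapper` 调用（包括最后的 Transition）：TriangleAttention（starting node）-> TriangleAttention（ending node）-> TriangleMultiplication（outgoing）-> TriangleMultiplication（incoming）-> Transition（pair_transition）。若 `global_config.use_remat` 为真，block 会先用 `hk.remat` 包装。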
import haiku as hk


class TemplatePairStack(hk.Module):
  """Stack of pair-update blocks applied to one template's pair features.

  Reference: Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack".
  """

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    # Per-module settings: `num_block` plus one sub-config per sub-layer.
    self.config = config
    # Model-wide settings (e.g. `use_remat`); also passed to every sub-layer.
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
    """Applies `config.num_block` pair-update blocks to `pair_act`.

    Each block runs, in order: triangle attention (starting node), triangle
    attention (ending node), triangle multiplication (outgoing), triangle
    multiplication (incoming) and the pair transition. Every sub-layer is
    invoked through `dropout_wrapper` with its own split RNG key.

    Arguments:
      pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
      pair_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: Safe key object encapsulating the random number generation key.
        If omitted, one is created from `hk.next_rng_key()`.

    Returns:
      Updated pair_act, shape [N_res, N_res, c_t]. When `config.num_block` is
      falsy, the input `pair_act` is returned unchanged.
    """
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    cfg = self.config
    glob = self.global_config

    # Guard clause: a stack with no blocks is the identity.
    if not cfg.num_block:
      return pair_act

    # (module class, config field == module name) in application order.
    # Order matters: it determines which split key each sub-layer receives.
    sublayers = (
        (TriangleAttention, 'triangle_attention_starting_node'),
        (TriangleAttention, 'triangle_attention_ending_node'),
        (TriangleMultiplication, 'triangle_multiplication_outgoing'),
        (TriangleMultiplication, 'triangle_multiplication_incoming'),
        (Transition, 'pair_transition'),
    )

    def one_block(carry):
      """Runs every sub-layer once, threading (pair_act, safe_key) through."""
      act, key = carry
      # 1 + len(sublayers) == 6 keys: the first is returned with the
      # activations (presumably for the next stacked block); the remaining
      # five go to the sub-layers, one each, in table order.
      key, *layer_keys = key.split(len(sublayers) + 1)
      for (layer_cls, field), layer_key in zip(sublayers, layer_keys):
        layer = layer_cls(getattr(cfg, field), glob, name=field)
        act = dropout_wrapper(
            layer,
            act,
            pair_mask,
            layer_key,
            is_training=is_training,
            global_config=glob)
      return act, key

    if glob.use_remat:
      # Rematerialize (recompute) block activations in the backward pass.
      one_block = hk.remat(one_block)

    stacked = layer_stack.layer_stack(cfg.num_block)(one_block)
    pair_act, _ = stacked((pair_act, safe_key))
    return pair_act