haiku实现TemplatePairStack类

TemplatePairStack是实现蛋白质结构模版pair_act特征表示的类:
通过layer_stack.layer_stack(c.num_block)(block) 堆叠c.num_block(配置文件中为2)block 函数,每个block对输入pair_act 和 pair_mask执行计算流程:TriangleAttention ---> dropout ->TriangleAttention ---> dropout -> TriangleMultiplication ---> dropout -> TriangleMultiplication ---> dropout -> Transition

复制代码
import haiku as hk


class TemplatePairStack(hk.Module):
  """Pair stack for the templates.

  Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
  """

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
    """Builds TemplatePairStack module.

    Arguments:
      pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
      pair_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: Safe key object encapsulating the random number generation key.

    Returns:
      Updated pair_act, shape [N_res, N_res, c_t].
    """

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    gc = self.global_config
    c = self.config

    if not c.num_block:
      return pair_act

    def block(x):
      """One block of the template pair stack."""
      pair_act, safe_key = x

      dropout_wrapper_fn = functools.partial(
          dropout_wrapper, is_training=is_training, global_config=gc)

      safe_key, *sub_keys = safe_key.split(6)
      sub_keys = iter(sub_keys)

      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_starting_node, gc,
                            name='triangle_attention_starting_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_ending_node, gc,
                            name='triangle_attention_ending_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                 name='triangle_multiplication_outgoing'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                 name='triangle_multiplication_incoming'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          Transition(c.pair_transition, gc, name='pair_transition'),
          pair_act,
          pair_mask,
          next(sub_keys))

      return pair_act, safe_key

    if gc.use_remat:
      block = hk.remat(block)

    res_stack = layer_stack.layer_stack(c.num_block)(block)
    pair_act, safe_key = res_stack((pair_act, safe_key))
    return pair_act
相关推荐
视觉语言导航12 分钟前
具身导航视角适应性增强!VIL:连续环境视觉语言导航的视角不变学习
人工智能·机器人·具身智能
猫先生Mr.Mao12 分钟前
2025年10月AGI月评|OmniNWM/X-VLA/DreamOmni2等6大开源项目:自动驾驶、机器人、文档智能的“技术底座”全解析
人工智能·机器人·大模型·自动驾驶·agi·大模型部署·分布式推理框架
WWZZ202517 分钟前
快速上手大模型:深度学习4(实践:多层感知机)
人工智能·深度学习·计算机视觉·机器人·大模型·slam·具身智能
zhangfeng11331 小时前
移动流行区间法(MEM)的原理和与LSTM、ARIMA等时间序列方法的区别
人工智能·rnn·lstm
数字化脑洞实验室2 小时前
如何理解不同行业AI决策系统的功能差异?
大数据·人工智能·算法
一点七加一2 小时前
Harmony鸿蒙开发0基础入门到精通Day07--JavaScript篇
开发语言·javascript·ecmascript
视觉语言导航2 小时前
RAPID:基于逆强化学习的无人机视觉导航鲁棒且敏捷规划器
人工智能·无人机·具身智能
阿郎_20112 小时前
python自动化脚本-简化留言
python·自动化
TextIn智能文档云平台2 小时前
大模型文档解析技术有哪些?
人工智能
大明者省2 小时前
案例分析交叉熵和交叉验证区别和联系
人工智能·深度学习·神经网络·计算机视觉·cnn