Applying the ensembled transforms to protein feature data multiple times

The wrap_ensemble_fn function applies the ensembled feature transforms to the protein features multiple times (e.g. repeated MSA sampling via sample_msa).

Key functions

tree.map_structure: applies a function to every leaf of a nested structure (here a dict of tensors) and returns a structure of the same shape. Below it builds a tf.TensorSpec for each tensor in the feature dict:

fn_output_signature = tree.map_structure(tf.TensorSpec.from_tensor, tensors_0)

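A minimal runnable sketch (tensor names and shapes are made up for illustration):

import tensorflow.compat.v1 as tf
import tree

tensors_0 = {'aatype': tf.zeros([100], tf.int32),
             'msa_feat': tf.zeros([508, 100, 49], tf.float32)}
fn_output_signature = tree.map_structure(tf.TensorSpec.from_tensor, tensors_0)
# The result mirrors the dict:
# {'aatype': TensorSpec([100], tf.int32),
#  'msa_feat': TensorSpec([508, 100, 49], tf.float32)}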
tf.map_fn: applies a function to each element of a tensor along the leading axis.
tensors = tf.map_fn(
    lambda x: wrap_ensemble_fn(tensors, x),
    tf.range(num_ensemble),
    parallel_iterations=1,
    fn_output_signature=fn_output_signature)
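A toy example (hypothetical values) of why fn_output_signature matters: the mapped function here turns each scalar into a dict, so the output structure cannot be inferred from the input alone:

import tensorflow.compat.v1 as tf

xs = tf.range(3)
out = tf.map_fn(
    lambda i: {'i': i, 'sq': tf.cast(i * i, tf.float32)},
    xs,
    fn_output_signature={'i': tf.TensorSpec([], tf.int32),
                         'sq': tf.TensorSpec([], tf.float32)})
# Results are stacked along a new leading axis: out['i'].shape == (3,).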
import copy
import tensorflow.compat.v1 as tf
import tree
import pickle
import numpy as np
import ml_collections
 
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
 
CONFIG = ml_collections.ConfigDict({
    'data': {
        'common': {
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': False,
        },
        'eval': {
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': True,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
            'crop_size': 100,
        },
    },
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'first': False,
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True,
                    'fuse_projection_weights': False,
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'max_relative_feature': 32,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'seq_channel': 384,
            'template': {
                'attention': {
                    'gating': False,
                    'key_dim': 64,
                    'num_head': 4,
                    'value_dim': 64
                },
                'dgram_features': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39
                },
                'embed_torsion_angles': False,
                'enabled': False,
                'template_pair_stack': {
                    'num_block': 2,
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'fuse_projection_weights': False,
                    },
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                },
                'max_templates': 4,
                'subbatch_size': 128,
                'use_template_unit_vector': False,
            }
        },
        'global_config': {
            'deterministic': False,
            'multimer_mode': False,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True,
            'eval_dropout': False,
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'predicted_aligned_error': {
                # `num_bins - 1` bins uniformly space the
                # [0, max_error_bin A] range.
                # The final bin covers [max_error_bin A, +infty]
                # 31A gives bins with 0.5A width.
                'max_error_bin': 31.,
                'num_bins': 64,
                'num_channels': 128,
                'filter_by_resolution': True,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'structure_module': {
                'num_layer': 8,
                'fape': {
                    'clamp_distance': 10.0,
                    'clamp_type': 'relu',
                    'loss_unit_distance': 10.0
                },
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'compute_in_graph_metrics': True,
                'dropout': 0.1,
                'num_channel': 384,
                'num_head': 12,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 10.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5,
                    'length_scale': 10.,
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'masked_msa': {
                'num_output': 23,
                'weight': 2.0
            },
        },
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    },
})
 
 
_MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
    'true_msa'
]
 
 
class SeedMaker(object):
  """Return unique seeds."""
 
  def __init__(self, initial_seed=0):
    self.next_seed = initial_seed
 
  def __call__(self):
    i = self.next_seed
    self.next_seed += 1
    return i
 
 
def shape_list(x):
  """Return list of dimensions of a tensor, statically where possible.
  Like `x.shape.as_list()` but with tensors instead of `None`s.
  Args:
    x: A tensor.
  Returns:
    A list with length equal to the rank of the tensor. The n-th element of the
    list is an integer when that dimension is statically known otherwise it is
    the n-th element of `tf.shape(x)`.
  """
  x = tf.convert_to_tensor(x)
 
  # If unknown rank, return dynamic shape
  if x.get_shape().dims is None:
    return tf.shape(x)
 
  static = x.get_shape().as_list()
  shape = tf.shape(x)
 
  ret = []
  for i in range(len(static)):
    dim = static[i]
    if dim is None:
      dim = shape[i]
    ret.append(dim)
  return ret
 
 
def shaped_categorical(probs, epsilon=1e-10):
  """Draw one categorical sample per position from the last axis of `probs`."""
  ds = shape_list(probs)
  num_classes = ds[-1]
  counts = tf.random.categorical(
      tf.reshape(tf.log(probs + epsilon), [-1, num_classes]),
      1,
      dtype=tf.int32)
  return tf.reshape(counts, ds[:-1])
 
 
def data_transforms_curry1(f):
  """Supply all arguments but the first."""
 
  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)
 
  return fc
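# Example: sample_msa(512, keep_extra=True) returns a one-argument function
# lambda protein: sample_msa(protein, 512, keep_extra=True), so transforms can
# be collected in a list and composed into a single pipeline.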
 
 
 
@data_transforms_curry1
def sample_msa(protein, max_seq, keep_extra):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.
  Args:
    protein: batch to sample msa from.
    max_seq: number of sequences to sample.
    keep_extra: When True sequences not sampled are put into fields starting
      with 'extra_*'.
  Returns:
    Protein with sampled msa.
  """
  num_seq = tf.shape(protein['msa'])[0]
  # The sequence at index 0 is the query, so it is always kept first.
  shuffled = tf.random_shuffle(tf.range(1, num_seq))
  index_order = tf.concat([[0], shuffled], axis=0)
  num_sel = tf.minimum(max_seq, num_seq)
  # tf.split splits the tensor along an axis: the first piece has size
  # num_sel, the second num_seq - num_sel.
  sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel])
 
  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      if keep_extra:
        # tf.gather collects elements from the input tensor by index.
        protein['extra_' + k] = tf.gather(protein[k], not_sel_seq)
      protein[k] = tf.gather(protein[k], sel_seq)
 
  return protein
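# Toy illustration (hypothetical shapes): for a batch where
# protein['msa'].shape == (10, 5) and max_seq=4, sample_msa keeps the query
# plus 3 randomly chosen sequences in 'msa', and with keep_extra=True moves
# the remaining 6 sequences into 'extra_msa'.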
 
 
@data_transforms_curry1
def make_masked_msa(protein, config, replace_fraction):
  """Create data for BERT on raw MSA."""
  # Add a random amino acid uniformly
  random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)
  # Build the replacement distribution; the probability of drawing a given
  # amino acid depends on how conserved it is in the MSA.
  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * protein['hhblits_profile'] +
      config.same_prob * tf.one_hot(protein['msa'], 22))
 
  # print(tf.reduce_sum(categorical_probs, axis=-1))  # every row sums to 0.3
 
  # Put all remaining probability on [MASK] which is a new column
 
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  # mask_prob = 0.7; the other probabilities sum to 0.3.
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  # Pad categorical_probs with mask_prob as a new last column, giving the full
  # per-position distribution (20 amino acids + gap + X + [MASK]).
  categorical_probs = tf.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)
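  # e.g. a row [p_1, ..., p_22] summing to 0.3 becomes
  # [p_1, ..., p_22, 0.7], a properly normalized distribution.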
 
  # print(tf.reduce_sum(categorical_probs, axis=-1))  # now every row sums to 1.0
 
  sh = shape_list(protein['msa'])
  # Draw uniform [0, 1) samples of shape sh; comparing them with
  # replace_fraction (0.15) picks the random positions to corrupt.
  mask_position = tf.random.uniform(sh) < replace_fraction
  
  # Sample from the distribution. Note how the randomness is produced: [MASK]
  # is the most likely outcome, and the chance of drawing any other amino acid
  # grows with its conservation in the MSA.
  bert_msa = shaped_categorical(categorical_probs)
  # Roughly 15% of positions are replaced; of those, ~70% become [MASK] and
  # ~30% a sampled amino acid (the more conserved a residue is at a position,
  # the more likely it is drawn). So bert_msa holds about 0.7 * 0.15 [MASK]
  # tokens plus a mixture of correct and corrupted residues.
  bert_msa = tf.where(mask_position, bert_msa, protein['msa'])
 
  # Mix real and masked MSA
  protein['bert_mask'] = tf.cast(mask_position, tf.float32)
  protein['true_msa'] = protein['msa']
  protein['msa'] = bert_msa
 
  return protein
 
 
@data_transforms_curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
 
  # Determine how much weight we assign to each agreement.  In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask
  # Weight 1 for the 20 amino acids and X; gap gets gap_agreement_weight
  # (0 by default) and the BERT mask column always gets weight 0.
  weights = tf.concat([
      tf.ones(21),
      gap_agreement_weight * tf.ones(1),
      np.zeros(1)], 0)
 
  # Make agreement score as weighted Hamming distance
  # Add a trailing axis so the mask broadcasts over the one-hot channels.
  sample_one_hot = (protein['msa_mask'][:, :, None] *
                    tf.one_hot(protein['msa'], 23))
  extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                   tf.one_hot(protein['extra_msa'], 23))
 
  num_seq, num_res, _ = shape_list(sample_one_hot)
  extra_num_seq, _, _ = shape_list(extra_one_hot)
 
  # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
  # in an optimized fashion to avoid possible memory or computation blowup.
  # Score the similarity of each extra MSA sequence to each sampled MSA
  # sequence: the more identical residues, the higher the agreement.
  # Amino-acid properties are ignored (a possible improvement); note the
  # per-class weights defined above.
  agreement = tf.matmul(
      tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
      tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
      transpose_b=True)
 
  # Assign each sequence in the extra sequences to the closest MSA sample
  # For each extra MSA sequence, pick the most similar sampled MSA sequence.
  protein['extra_cluster_assignment'] = tf.argmax(
      agreement, axis=1, output_type=tf.int32)
 
  return protein


@data_transforms_curry1
def summarize_clusters(protein):
  """Produce profile and deletion_matrix_mean within each cluster."""
  num_seq = shape_list(protein['msa'])[0]
  def csum(x):
    return tf.math.unsorted_segment_sum(
        x, protein['extra_cluster_assignment'], num_seq)
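  # e.g. (hypothetical) csum over data [a, b, c] with
  # extra_cluster_assignment [0, 0, 1] and num_seq 2 yields [a + b, c].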
 
  mask = protein['extra_msa_mask']
  mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center
  
  # Result tensor [num_seq, num_res, 23]: row i sums the one-hot encodings of
  # the extra MSA sequences whose nearest neighbor is sampled sequence i.
  msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
  msa_sum += tf.one_hot(protein['msa'], 23)  # Original sequence
  protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
 
  del msa_sum
 
  # Per-cluster deletion counts: del_sum has shape [num_seq, num_res], and row
  # i sums the per-position deletion counts of the extra MSA sequences
  # assigned to sampled sequence i.
  del_sum = csum(mask * protein['extra_deletion_matrix'])
  del_sum += protein['deletion_matrix']  # Original sequence
  protein['cluster_deletion_mean'] = del_sum / mask_counts
  del del_sum
 
  return protein


@data_transforms_curry1
def crop_extra_msa(protein, max_extra_msa):
  """MSA features are cropped so only `max_extra_msa` sequences are kept."""
  num_seq = tf.shape(protein['extra_msa'])[0]
  num_sel = tf.minimum(max_extra_msa, num_seq)
  select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel]
  for k in _MSA_FEATURE_NAMES:
    if 'extra_' + k in protein:
      protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices)
  return protein


@data_transforms_curry1
def make_msa_feat(protein):
  """Create and concatenate MSA features."""
  # Whether there is a domain break. Always zero for chains, but keeping
  # for compatibility with domain datasets.
  has_break = tf.clip_by_value(
      tf.cast(protein['between_segment_residues'], tf.float32),
      0, 1)
  aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1)
  target_feat = [
      tf.expand_dims(has_break, axis=-1),
      aatype_1hot,  # Everyone gets the original sequence.
  ]
  msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1)
  has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.)
  deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)
  msa_feat = [
      msa_1hot,
      tf.expand_dims(has_deletion, axis=-1),
      tf.expand_dims(deletion_value, axis=-1),
  ]
  if 'cluster_profile' in protein:
    deletion_mean_value = (
        tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
    msa_feat.extend([
        protein['cluster_profile'],
        tf.expand_dims(deletion_mean_value, axis=-1),
    ])
  if 'extra_deletion_matrix' in protein:
    protein['extra_has_deletion'] = tf.clip_by_value(
        protein['extra_deletion_matrix'], 0., 1.)
    protein['extra_deletion_value'] = tf.atan(
        protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)
  protein['msa_feat'] = tf.concat(msa_feat, axis=-1)
  protein['target_feat'] = tf.concat(target_feat, axis=-1)
  return protein


@data_transforms_curry1
def select_feat(protein, feature_list):
  return {k: v for k, v in protein.items() if k in feature_list}


@data_transforms_curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
  """Crop randomly to `crop_size`, or keep as is if shorter than that."""
  seq_length = protein['seq_length']
  if 'template_mask' in protein:
    num_templates = tf.cast(
        shape_list(protein['template_mask'])[0], tf.int32)
  else:
    num_templates = tf.constant(0, dtype=tf.int32)
  num_res_crop_size = tf.math.minimum(seq_length, crop_size)
 
  # Ensures that the cropping of residues and templates happens in the same way
  # across ensembling iterations.
  # Do not use for randomness that should vary in ensembling.
  seed_maker = SeedMaker(initial_seed=protein['random_crop_to_size_seed'])
 
  if subsample_templates:
    templates_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
        seed=seed_maker())
  else:
    templates_crop_start = 0
 
  num_templates_crop_size = tf.math.minimum(
      num_templates - templates_crop_start, max_templates)
 
  num_res_crop_start = tf.random.stateless_uniform(
      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
      dtype=tf.int32, seed=seed_maker())
 
  # Generate one randomly shuffled index order, shared by all template
  # features that need cropping.
  # tf.argsort returns the indices that would sort the tensor;
  # tf.random.stateless_uniform draws a uniformly distributed random tensor of
  # the given shape. num_templates is a scalar, so [num_templates] makes the
  # shape a length-num_templates vector.
  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
      [num_templates], seed=seed_maker()))
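  # Worked example (hypothetical): 'msa_feat' has schema
  # [NUM_MSA_SEQ, NUM_RES, None] and shape [512, 120, 49]; with crop_size=100
  # the loop below yields crop_starts = [0, num_res_crop_start, 0] and
  # crop_sizes = [512, 100, 49], so only the residue axis is cropped.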
 
  for k, v in protein.items():
    if k not in shape_schema or (
        'template' not in k and NUM_RES not in shape_schema[k]):
      continue
 
    # randomly permute the templates before cropping them.
    if k.startswith('template') and subsample_templates:
      v = tf.gather(v, templates_select_indices)
 
    crop_sizes = []
    crop_starts = []
    
    # zip pairs each dimension placeholder in shape_schema[k] with the actual
    # dimension value from shape_list(v).
    for i, (dim_size, dim) in enumerate(zip(shape_schema[k], shape_list(v))):
      is_num_res = (dim_size == NUM_RES)
      if i == 0 and k.startswith('template'):
        crop_size = num_templates_crop_size
        crop_start = templates_crop_start
      else:
        crop_start = num_res_crop_start if is_num_res else 0
        crop_size = (num_res_crop_size if is_num_res else
                     (-1 if dim is None else dim))
      crop_sizes.append(crop_size)
      crop_starts.append(crop_start)
    protein[k] = tf.slice(v, crop_starts, crop_sizes)
 
  protein['seq_length'] = num_res_crop_size
  return protein


@data_transforms_curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res, num_templates=0):
  """Guess at the MSA and sequence dimensions to make fixed size."""
 
  pad_size_map = {
      NUM_RES: num_res,
      NUM_MSA_SEQ: msa_cluster_size,
      NUM_EXTRA_SEQ: extra_msa_size,
      NUM_TEMPLATES: num_templates,
  }
 
  for k, v in protein.items():
    # Don't transfer this to the accelerator.
    if k == 'extra_cluster_assignment':
      continue
    shape = v.shape.as_list()
    # The dimension placeholders declared for this feature.
    schema = shape_schema[k]
    assert len(shape) == len(schema), (
        f'Rank mismatch between shape and shape schema for {k}: '
        f'{shape} vs {schema}')
    
    # Target size for each dimension of the feature tensor, looked up in
    # pad_size_map. dict.get returns None for keys it does not know, in which
    # case the current size s1 from zip(shape, schema) is kept.
    pad_size = [
        pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
    ]
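    # e.g. (hypothetical) 'msa_feat' with shape [508, 100, 49] and schema
    # [NUM_MSA_SEQ, NUM_RES, None], msa_cluster_size=512, num_res=100 gives
    # pad_size [512, 100, 49] and padding [(0, 4), (0, 0), (0, 0)].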
    # Pad at the end of each dimension; the number of zeros to add is the
    # target size minus the current size (p - tf.shape(v)[i]).
    padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)]
    if padding:
      protein[k] = tf.pad(
          v, padding, name=f'pad_to_fixed_{k}')
      protein[k].set_shape(pad_size)
  return protein
 
 
def ensembled_map_fns(data_config):
  """Input pipeline functions that can be ensembled and averaged."""
  common_cfg = data_config.common
  eval_cfg = data_config.eval
 
  map_fns = []
 
  if common_cfg.reduce_msa_clusters_by_max_templates:
    pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
  else:
    pad_msa_clusters = eval_cfg.max_msa_clusters
 
  max_msa_clusters = pad_msa_clusters
  max_extra_msa = common_cfg.max_extra_msa
 
  map_fns.append(sample_msa(max_msa_clusters, keep_extra=True))
 
  if 'masked_msa' in common_cfg:
    # Masked MSA should come *before* MSA clustering so that
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    map_fns.append(make_masked_msa(common_cfg.masked_msa,
                                   eval_cfg.masked_msa_replace_fraction))
 
  if common_cfg.msa_cluster_features:
    map_fns.append(nearest_neighbor_clusters())
    
    map_fns.append(summarize_clusters())
    
  # Crop after creating the cluster profiles.
  if max_extra_msa:
    map_fns.append(crop_extra_msa(max_extra_msa))
  else:
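    # delete_extra_msa is defined in AlphaFold's data_transforms module and is
    # not reproduced in this excerpt; with max_extra_msa=1024 in this CONFIG
    # the branch is never taken.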
    map_fns.append(delete_extra_msa)
 
  map_fns.append(make_msa_feat())
 
  crop_feats = dict(eval_cfg.feat)
 
  if eval_cfg.fixed_size:
    map_fns.append(select_feat(list(crop_feats)))
    map_fns.append(random_crop_to_size(
        eval_cfg.crop_size,
        eval_cfg.max_templates,
        crop_feats,
        eval_cfg.subsample_templates))
    map_fns.append(make_fixed_size(
        crop_feats,
        pad_msa_clusters,
        common_cfg.max_extra_msa,
        eval_cfg.crop_size,
        eval_cfg.max_templates))
  else:
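    # crop_templates likewise comes from AlphaFold's data_transforms module;
    # with fixed_size=True in this CONFIG this branch is not taken.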
    map_fns.append(crop_templates(eval_cfg.max_templates))
 
  return map_fns
 
 
@data_transforms_curry1
def compose(x, fs):
  for f in fs:
    x = f(x)
  return x
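# compose(fns) folds the transform list into a single function:
# compose([f, g, h])(x) == h(g(f(x))).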
 

### Get the configuration data
data_config = CONFIG.data 
eval_cfg = data_config.eval
common_cfg = data_config.common
 
crop_feats = dict(eval_cfg.feat)
#pad_msa_clusters = eval_cfg.max_msa_clusters
 
shape_schema = crop_feats
num_ensemble = eval_cfg.num_ensemble


def wrap_ensemble_fn(data, i):
  """Function to be mapped over the ensemble dimension."""
  d = data.copy()
  fns = ensembled_map_fns(data_config)
  fn = compose(fns)
  d['ensemble_index'] = i
  return fn(d)


### Load the input; these protein features were already processed by the nonensembled transforms
with open("Human_HBB_tensor_dict_nonensembled.pkl",'rb') as f:
   Human_HBB_tensor = pickle.load(f)
 
protein = copy.deepcopy(Human_HBB_tensor)
 
# Add the protein['deletion_matrix'] feature; without it the pipeline errors out.
protein['deletion_matrix'] = tf.cast(protein['deletion_matrix_int'], dtype=tf.float32) 
 
protein_0 = wrap_ensemble_fn(protein, tf.constant(0))

if data_config.common.resample_msa_in_recycling:
  # Separate batch per ensembling & recycling step.
  num_ensemble *= data_config.common.num_recycle + 1
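  # With this CONFIG: num_ensemble = 1 * (3 + 1) = 4.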
 

if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
  fn_output_signature = tree.map_structure(
        tf.TensorSpec.from_tensor, protein_0)
  # When fn's output does not share the input's nested structure, tf.map_fn
  # needs the fn_output_signature argument to specify the output structure
  # explicitly.
  protein = tf.map_fn(
        lambda x: wrap_ensemble_fn(protein, x),
        tf.range(num_ensemble),
        parallel_iterations=1,
        fn_output_signature=fn_output_signature)
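  # Every feature now carries a leading ensemble axis; with this CONFIG
  # msa_feat becomes [4, 512, 100, 49].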
else:
  # Add a leading (ensemble) dimension of size 1.
  protein = tree.map_structure(lambda x: x[None],
                                 protein_0)
 
print(f"ensembled函数处理前:")
print(f"特征数:{len(Human_HBB_tensor)}")
print(f"特征:{Human_HBB_tensor.keys()}")
print(Human_HBB_tensor['aatype'].shape)
#print(Human_HBB_tensor['aatype'])
      
print(f"ensembled函数处理后:")
print(f"特征数:{len(protein)}")
print(f"特征:{protein.keys()}")
print(protein['extra_msa'].shape)
print(protein['aatype'].shape)
print(protein['msa_feat'].shape)

print("protein_0['msa_feat'].shape")
print(protein_0['msa_feat'].shape)