预训练蛋白质语言模型ESM-2保姆级使用教程

  • ESM-2(Evolutionary Scale Modeling 2)当下最先进的预训练蛋白质语言模型之一, 由Facebook AI

    Research开发,最新版本使用了48层Transformer编码器架构,有150亿参数。

  • 蛋白质语言模型(PLM)可以用于理解和预测蛋白质序列的特性,包括它们的结构、功能等。

以下是使用ESM-2的详细教程:


python 复制代码
# pip install esm
# pip install torch
# pip安装的esm不包括模型,size is small,but torch is big, more than 1G.
import torch
import esm
import os

我们先指定一个新的目录路径,然后写入环境变量字典,设置TORCH_HOME环境变量

  • 为了方便管理模型,我们可以设置TORCH_HOME环境变量,将模型下载到我们指定的目录;

  • 在运行esm.pretrained.esm2_t33_650M_UR50D()时,PyTorch会检查这个目录,如果模型esm2_t33_650M_UR50D已经存在,它会从那里加载模型,否则它会从网上下载模型并保存在这个目录下。

  • 例如:当环境变量设置为'D:\Desktop\model'时,模型的下载地址为'D:\Desktop\model\hub\checkpoints\esm2_t33_650M_UR50D.pt'

  • 设置 TORCH_HOME 环境变量后,所有 PyTorch 相关的库(比如 torch.hub 或 transformers)在下载模型和数据集时,都会使用这个目录作为下载位置。

  • 注:

    PyTorch和ESM都是Facebook的产品;

    os.environ 返回一个代表当前环境变量的字典对象。

python 复制代码
new_dir = 'D:\Desktop\model'
os.environ['TORCH_HOME'] = new_dir

下载模型到我们上面指定的目录,或者从指定的目录加载模型;alphabet代表模型使用的字母表,它定义了模型能够处理的字符集合。

  • Size of esm2_t33_650M_UR50D is very big, about 2.4G, 这个模型使用了33层Transformer编码器架构,有650百万(6.5亿参数),使用UniRef50作为训练集;

  • 关于UniRef100、UniRef90和UniRef50的知识,请参考:https://pubmed.ncbi.nlm.nih.gov/17379688/

  • esm包含好几个版本的蛋白质语言预训练模型,可以通过esm.pretrained.xxx指定使用不同版本:esm.pretrained.esm2_t36_3B_UR50D...

python 复制代码
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

构建数据格式转换器

python 复制代码
batch_converter = alphabet.get_batch_converter()

将模型设置为评估模式,这会关闭dropout等训练特有的行为

  • 在神经网络中,"dropout" 是一种正则化技术,用于防止或减少模型的过拟合。Dropout通过在训练过程中随机"丢弃"(即暂时移除)网络中的一些神经元(包括它们所有的连接),来减少神经元之间复杂的共适应关系,从而促进模型的泛化能力。

  • 在训练神经网络时启用dropout,以减少过拟合。

  • 在模型评估或预测时禁用dropout,确保所有神经元都参与工作。

python 复制代码
model.eval()

demo数据

python 复制代码
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE")
]

使用数据格式转换器将数据转换为模型可以理解/处理的tokens,并将这些tokens填充为相同的长度,然后批量计算每个序列的长度

  • batch_labels:存储每个序列的标签或ID,可能用于后续的监督学习任务;

  • batch_strs:存储原始的蛋白质序列字符串,可能用于调试或显示目的;

  • batch_tokens:是一个二维Tensor张量,一行代表一个蛋白质序列,存储转换后的tokens,这些tokens是原始序列中氨基酸的整数索引(0-20,20种氨基酸);

  • alphabet.padding_idx是一个填充矩阵;

  • batch_tokens != alphabet.padding_idx 生成一个布尔矩阵,实际有氨基酸的位置是True,填充的位置是False;

  • .sum(1) 对这个布尔矩阵沿着维度1求和,即生成每个序列的长度;

  • 这样做的目的是批量求每个序列的长度。

python 复制代码
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

提取蛋白质序列中每个氨基酸残基的 token_representations,获取模型对输入序列的深层次理解

  • with torch.no_grad(): 禁用PyTorch中的梯度计算。在模型评估或预测阶段,我们不需要进行反向传播,禁用梯度计算可以减少内存消耗并加速计算。使用with表明该操作是暂时的。

  • repr_layers=[33] 表示使用Transformer架构第33层的输出作为特征representation;

  • return_contacts=True 表示获取模型预测的氨基酸残基之间的接触图,这在蛋白质结构预测中是一个有用的特征;

  • results["representations"][33] 表示从模型输出results中提取第33层的representations;

  • results是一个字典,"representations"是这个字典的一个键(key),该键对应的值(value)也是一个字典,存放着每一层的token_representations。字典中嵌套字典。

python 复制代码
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

生成每一个序列的 representations

  • token 0是序列开始标记, 所以第一个氨基酸残基是token 1;最后一个token是序列结束标记;
  • token_representations 是一个三维Tensor张量(可以理解为三维数组);
  • token_representations[i, 1:tokens_len-1] 使用切片取出每条蛋白质序列的token_representations(二维,根据氨基酸数量取行,列全取);
  • .mean(0) 使用每条蛋白质序列的二维token_representations的列平均值作为该蛋白质序列的representations(sequence_representations)
python 复制代码
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):  # 例如,氨基酸token为65个,则batch_lens=67
    sequence_representations.append(token_representations[i, 1:tokens_len-1].mean(0))

生成使用无监督学习方法预测的蛋白质内部残基间的接触图

python 复制代码
import matplotlib.pyplot as plt
for (ID, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(ID)
    path = os.path.join(os.getcwd(), ID)
    plt.savefig(path, bbox_inches="tight")

参考:https://github.com/facebookresearch/esm

相关推荐
nfgo5 分钟前
Apollo自动驾驶项目(二:cyber框架分析)
人工智能·自动驾驶·unix
h1771134720511 分钟前
基于区块链的相亲交易系统源码解析
大数据·人工智能·安全·系统架构·交友
HPC_fac1305206781621 分钟前
RTX 4090 系列即将停产,RTX 5090 系列蓄势待发
服务器·人工智能·gpu算力
xuehaisj1 小时前
论文内容分类与检测系统源码分享
人工智能·分类·数据挖掘
大耳朵爱学习1 小时前
大模型预训练的降本增效之路——从信息密度出发
人工智能·深度学习·机器学习·自然语言处理·大模型·llm·大语言模型
loongloongz2 小时前
联合条件概率 以及在语言模型中的应用
人工智能·语言模型·自然语言处理·概率论
lijfrank2 小时前
情感计算领域期刊与会议
人工智能·人机交互
sp_fyf_20242 小时前
计算机人工智能前沿进展-大语言模型方向-2024-09-18
人工智能·语言模型·自然语言处理
sp_fyf_20242 小时前
计算机人工智能前沿进展-大语言模型方向-2024-09-14
人工智能·语言模型·自然语言处理
ybdesire2 小时前
nanoGPT用红楼梦数据从头训练babyGPT-12.32M实现任意问答
人工智能·深度学习·语言模型