# predict.py
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import OneHotEncoder
from model import FCNet
import pickle
import os
import sklearn
from packaging import version
# 自定义函数用于加载模型和编码器
def load_model(model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件 '{model_path}' 不存在。")
# 加载模型文件
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
encoder = checkpoint['encoder']
amino_acids = checkpoint['amino_acids']
# 定义模型
feature_dim1 = 3 # 最多三个氨基酸
feature_dim2 = len(amino_acids) + 1 # one-hot + 数值特征
model = FCNet(feature_dim1, feature_dim2)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
return model, encoder, amino_acids
# 自定义函数进行预测
def predict(model, encoder, amino_acids, single_dict, aa_input):
# 检查输入类型和长度
if not isinstance(aa_input, str):
raise ValueError("输入的氨基酸组合必须是字符串,例如 'A', 'AC', 'ACD'。")
aa_input = aa_input.upper()
length = len(aa_input)
if length == 0 or length > 3:
raise ValueError("输入的氨基酸组合长度必须在1到3之间。")
# 填充至三个氨基酸
if length == 1:
aa_triplet = aa_input + '--'
elif length == 2:
aa_triplet = aa_input + '-'
else:
aa_triplet = aa_input
# 检查是否所有氨基酸都在编码器中
for aa in aa_triplet:
if aa not in amino_acids:
raise ValueError(f"氨基酸 '{aa}' 不在编码器的氨基酸列表中。")
# 独热编码
try:
one_hot1 = encoder.transform([[aa_triplet[0]]])[0]
one_hot2 = encoder.transform([[aa_triplet[1]]])[0]
one_hot3 = encoder.transform([[aa_triplet[2]]])[0]
except Exception as e:
raise ValueError(f"无法对氨基酸组合 '{aa_triplet}' 进行编码。错误信息:{e}")
# 读取单个氨基酸的数值特征
value1_aa1 = single_dict.get(aa_triplet[0], 0.0)
value1_aa2 = single_dict.get(aa_triplet[1], 0.0)
value1_aa3 = single_dict.get(aa_triplet[2], 0.0)
# 构建特征向量
feature_part1 = np.concatenate([one_hot1, [value1_aa1]])
feature_part2 = np.concatenate([one_hot2, [value1_aa2]])
feature_part3 = np.concatenate([one_hot3, [value1_aa3]])
feature = np.stack([feature_part1, feature_part2, feature_part3]) # (3, N +1)
# 转换为张量并添加 batch 维度
feature_tensor = torch.tensor(feature, dtype=torch.float32).unsqueeze(0) # (1, 3, N +1)
# 进行预测
with torch.no_grad():
output = model(feature_tensor)
# 获取预测结果
predicted_values = output.squeeze(0).numpy() # (3,)
return predicted_values
def main():
import pandas as pd # 确保 pandas 已导入
# 模型文件路径
model_path = '../models/model.pth'
# 加载模型、编码器和氨基酸列表
try:
model, encoder, amino_acids = load_model(model_path)
except Exception as e:
print(f"加载模型失败:{e}")
return
# 加载单个氨基酸的数值字典
single_dict_path = '../data/single_dict.pkl'
if not os.path.exists(single_dict_path):
# 如果未保存,读取 single.csv 并创建字典,然后保存
single_csv = '../data/single.csv'
if not os.path.exists(single_csv):
raise FileNotFoundError(f"单个氨基酸数据文件 '{single_csv}' 不存在。")
single_df = pd.read_csv(single_csv, header=None, names=['AminoAcid', 'Value'])
single_dict = single_df.set_index('AminoAcid')['Value'].to_dict()
# 保存字典
with open(single_dict_path, 'wb') as f:
pickle.dump(single_dict, f)
else:
with open(single_dict_path, 'rb') as f:
single_dict = pickle.load(f)
# 提示用户输入氨基酸组合
while True:
aa_input = input("请输入氨基酸组合(例如 'A'、'AC'、'ACD'),输入 'exit' 退出:").strip().upper()
if aa_input.lower() == 'exit':
print("退出预测程序。")
break
try:
predicted = predict(model, encoder, amino_acids, single_dict, aa_input)
# 根据输入长度,打印对应数量的值
length = len(aa_input)
if length == 1:
print(f"预测结果 - Value1: {predicted[0]:.6f}")
elif length == 2:
print(f"预测结果 - Value1: {predicted[0]:.6f}, Value2: {predicted[1]:.6f}")
else:
print(f"预测结果 - Value1: {predicted[0]:.6f}, Value2: {predicted[1]:.6f}, Value3: {predicted[2]:.6f}")
except Exception as e:
print(f"错误:{e}")
if __name__ == '__main__':
main()
predict3
yyfhq2024-11-19 10:17
相关推荐
如若12322 分钟前
对文件内的文件名生成目录,方便查阅西猫雷婶1 小时前
python学opencv|读取图像(二十一)使用cv2.circle()绘制圆形进阶老刘莱国瑞1 小时前
STM32 与 AS608 指纹模块的调试与应用一只敲代码的猪2 小时前
Llama 3 模型系列解析(一)Hello_WOAIAI3 小时前
批量将 Word 文件转换为 HTML:Python 实现指南winfredzhang3 小时前
使用Python开发PPT图片提取与九宫格合并工具矩阵推荐官hy147623 小时前
短视频矩阵系统种类繁多,应该如何对比选择?测试19983 小时前
外包干了2年,技术退步明显....码银3 小时前
【python】银行客户流失预测预处理部分,独热编码·标签编码·数据离散化处理·数据筛选·数据分割小木_.4 小时前
【python 逆向分析某有道翻译】分析有道翻译公开的密文内容,webpack类型,全程扣代码,最后实现接口调用翻译,仅供学习参考