predict3

复制代码
# predict.py

import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import OneHotEncoder
from model import FCNet
import pickle
import os
import sklearn
from packaging import version


# 自定义函数用于加载模型和编码器
def load_model(model_path):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件 '{model_path}' 不存在。")

    # 加载模型文件
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    encoder = checkpoint['encoder']
    amino_acids = checkpoint['amino_acids']

    # 定义模型
    feature_dim1 = 3  # 最多三个氨基酸
    feature_dim2 = len(amino_acids) + 1  # one-hot + 数值特征
    model = FCNet(feature_dim1, feature_dim2)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model, encoder, amino_acids


# 自定义函数进行预测
def predict(model, encoder, amino_acids, single_dict, aa_input):
    # 检查输入类型和长度
    if not isinstance(aa_input, str):
        raise ValueError("输入的氨基酸组合必须是字符串,例如 'A', 'AC', 'ACD'。")

    aa_input = aa_input.upper()
    length = len(aa_input)
    if length == 0 or length > 3:
        raise ValueError("输入的氨基酸组合长度必须在1到3之间。")

    # 填充至三个氨基酸
    if length == 1:
        aa_triplet = aa_input + '--'
    elif length == 2:
        aa_triplet = aa_input + '-'
    else:
        aa_triplet = aa_input

    # 检查是否所有氨基酸都在编码器中
    for aa in aa_triplet:
        if aa not in amino_acids:
            raise ValueError(f"氨基酸 '{aa}' 不在编码器的氨基酸列表中。")

    # 独热编码
    try:
        one_hot1 = encoder.transform([[aa_triplet[0]]])[0]
        one_hot2 = encoder.transform([[aa_triplet[1]]])[0]
        one_hot3 = encoder.transform([[aa_triplet[2]]])[0]
    except Exception as e:
        raise ValueError(f"无法对氨基酸组合 '{aa_triplet}' 进行编码。错误信息:{e}")

    # 读取单个氨基酸的数值特征
    value1_aa1 = single_dict.get(aa_triplet[0], 0.0)
    value1_aa2 = single_dict.get(aa_triplet[1], 0.0)
    value1_aa3 = single_dict.get(aa_triplet[2], 0.0)

    # 构建特征向量
    feature_part1 = np.concatenate([one_hot1, [value1_aa1]])
    feature_part2 = np.concatenate([one_hot2, [value1_aa2]])
    feature_part3 = np.concatenate([one_hot3, [value1_aa3]])
    feature = np.stack([feature_part1, feature_part2, feature_part3])  # (3, N +1)

    # 转换为张量并添加 batch 维度
    feature_tensor = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)  # (1, 3, N +1)

    # 进行预测
    with torch.no_grad():
        output = model(feature_tensor)

    # 获取预测结果
    predicted_values = output.squeeze(0).numpy()  # (3,)
    return predicted_values


def main():
    import pandas as pd  # 确保 pandas 已导入

    # 模型文件路径
    model_path = '../models/model.pth'

    # 加载模型、编码器和氨基酸列表
    try:
        model, encoder, amino_acids = load_model(model_path)
    except Exception as e:
        print(f"加载模型失败:{e}")
        return

    # 加载单个氨基酸的数值字典
    single_dict_path = '../data/single_dict.pkl'
    if not os.path.exists(single_dict_path):
        # 如果未保存,读取 single.csv 并创建字典,然后保存
        single_csv = '../data/single.csv'
        if not os.path.exists(single_csv):
            raise FileNotFoundError(f"单个氨基酸数据文件 '{single_csv}' 不存在。")
        single_df = pd.read_csv(single_csv, header=None, names=['AminoAcid', 'Value'])
        single_dict = single_df.set_index('AminoAcid')['Value'].to_dict()
        # 保存字典
        with open(single_dict_path, 'wb') as f:
            pickle.dump(single_dict, f)
    else:
        with open(single_dict_path, 'rb') as f:
            single_dict = pickle.load(f)

    # 提示用户输入氨基酸组合
    while True:
        aa_input = input("请输入氨基酸组合(例如 'A'、'AC'、'ACD'),输入 'exit' 退出:").strip().upper()
        if aa_input.lower() == 'exit':
            print("退出预测程序。")
            break
        try:
            predicted = predict(model, encoder, amino_acids, single_dict, aa_input)
            # 根据输入长度,打印对应数量的值
            length = len(aa_input)
            if length == 1:
                print(f"预测结果 - Value1: {predicted[0]:.6f}")
            elif length == 2:
                print(f"预测结果 - Value1: {predicted[0]:.6f}, Value2: {predicted[1]:.6f}")
            else:
                print(f"预测结果 - Value1: {predicted[0]:.6f}, Value2: {predicted[1]:.6f}, Value3: {predicted[2]:.6f}")
        except Exception as e:
            print(f"错误:{e}")


if __name__ == '__main__':
    main()
相关推荐
七夜zippoe1 分钟前
异步编程实战:构建高性能Python网络应用
开发语言·python·websocket·asyncio·aiohttp
tianyuanwo2 分钟前
Python虚拟环境深度解析:从virtualenv到virtualenvwrapper
开发语言·python·virtualenv
越甲八千2 分钟前
ORM 的优势
数据库·python
是有头发的程序猿4 分钟前
Python爬虫防AI检测实战指南:从基础到高级的规避策略
人工智能·爬虫·python
grd45 分钟前
Electron for OpenHarmony 实战:Pagination 分页组件实现
python·学习
CryptoRzz6 分钟前
印度交易所 BSE 与 NSE 实时数据 API 接入指南
java·c语言·python·区块链·php·maven·symfony
山土成旧客13 分钟前
【Python学习打卡-Day35】从黑盒到“玻璃盒”:掌握PyTorch模型可视化、进度条与推理
pytorch·python·学习
@zulnger14 分钟前
python 学习笔记(循环)
笔记·python·学习
No_Merman20 分钟前
【DAY28】元组和os模块
python
iuu_star35 分钟前
金融数据-基于Streamlit的金融数据分析平台开发详解
python·金融·数据挖掘