注释标准模板

观看main函数能够看出框架,框架要简单,比如训练不给它细分,数据流向关注转为哪个数据,而不是关注维度,维度在调试的时候才关注

1、=>表示数据流向

2、# ============================================================================ #包围的表示框架

3、# 表示普通的框架内的注释

4、# -----补充:-----表示补充的注释

5、总体框架以及流向以及说明写在文件最开头

python 复制代码
"""
基于PSO的BP神经网络预测 - PyTorch版本
框架:
数据准备(设备设置|数据加载)→数据转换→参数设置→训练→结果展示以及预测

数据流向:
# =>加载数据P、T、P_test、T_test、cur_season=>数据转换P_tensor、T_tensor、P_test_tensor、T_test_tensor、cur_season_tensor
# =>particle粒子=>particle粒子、cost、GlobalBest、BestCost=>预测的结果值prob
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
import time
import torch
import torch.nn as nn
import torch.optim as optim
from bp_function import BpFunction

np.random.seed(42)
torch.manual_seed(42)

# -----补充:粒子类定义-----
class Particle:
    def __init__(self, position, velocity, cost, best_position, best_cost):
        self.Position = position # 粒子位置
        self.Velocity = velocity # 粒子速度
        self.Cost = cost # 粒子当前适应度
        self.Best = {'Position': best_position, 'Cost': best_cost} # 粒子个体最优

def main():
    start_time = time.time()
    
    # ============================================================================ #
    # 设备设置------最终数据都需要经过这个设备来进行计算
    # ============================================================================ #
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # =>加载数据P、T、P_test、T_test、cur_season
    # 数据加载
    print("加载数据...")
    P = np.loadtxt('P_train.txt')
    T = np.loadtxt('T_train.txt').reshape(-1, 1)
    P_test = np.loadtxt('P_test.txt')
    T_test = np.loadtxt('T_test.txt').reshape(-1, 1)
    cur_season = np.loadtxt('2022.txt')
    
    # =>加载数据P、T、P_test、T_test、cur_season=>数据转换P_tensor、T_tensor、P_test_tensor、T_test_tensor、cur_season_tensor
    # ============================================================================ #
    # # 数据转换------转换到PyTorch张量并移动到设备上
    # ============================================================================ #
    P_tensor = torch.FloatTensor(P).to(device)
    T_tensor = torch.FloatTensor(T).to(device)
    P_test_tensor = torch.FloatTensor(P_test).to(device)
    T_test_tensor = torch.FloatTensor(T_test).to(device)
    cur_season_tensor = torch.FloatTensor(cur_season).to(device)
    
    # ============================================================================ #
    # # 参数设置------训练参数设置1-BP参数
    # ============================================================================ #
    inputnum = P.shape[1]                    # 输入层神经元个数
    hiddennum = 2 * inputnum + 1             # 初始隐层神经元个数
    outputnum = T.shape[1]                  # 输出层神经元个数
    
    w1num = inputnum * hiddennum             # 输入层到隐层的权值个数
    w2num = outputnum * hiddennum            # 隐层到输出层的权值个数
    N = w1num + hiddennum + w2num + outputnum  # 待优化的变量个数
    print(f"网络结构: {inputnum} -> {hiddennum} -> {outputnum}")
    print(f"待优化变量个数: {N}")
    
    # ============================================================================ #
    # # 参数设置------训练参数设置2-PSO参数设置
    # ============================================================================ #
    nVar = N
    VarMin = -0.5
    VarMax = 0.5
    MaxIt = 200
    nPop = 40
    w = 1.0
    wdamp = 0.99
    c1 = 1.5
    c2 = 2.0
    VelMax = 0.1 * (VarMax - VarMin)
    VelMin = -VelMax
    

    # -----补充:粒子变量保存以及最佳解保存-----
    particles = []
    GlobalBest = {'Position': None, 'Cost': np.inf}
    
    # =>加载数据P、T、P_test、T_test、cur_season=>数据转换P_tensor、T_tensor、P_test_tensor、T_test_tensor、cur_season_tensor
    # =>particle粒子(获得第一次BP网络训练之后的结果,得到当前的最佳)
    # ============================================================================ #
    # 训练
    # ============================================================================ #
    print("初始化粒子群...")
    for i in range(nPop):
        position = np.random.uniform(VarMin, VarMax, nVar)
        velocity = np.zeros(nVar)
        cost, _ = BpFunction(position, P_tensor, T_tensor, hiddennum, 
                            P_test_tensor, T_test_tensor, device)
        best_position = position.copy()
        best_cost = cost
        
        particle = Particle(position, velocity, cost, best_position, best_cost)
        particles.append(particle)
        
        if best_cost < GlobalBest['Cost']:
            GlobalBest['Position'] = best_position.copy()
            GlobalBest['Cost'] = best_cost
    
    # 初始化变量,之后用来记录最优成本
    BestCost = np.zeros(MaxIt)
    
    # =>加载数据P、T、P_test、T_test、cur_season=>数据转换P_tensor、T_tensor、P_test_tensor、T_test_tensor、cur_season_tensor
    # =>particle粒子=>particle粒子、cost、GlobalBest、BestCost(PSO训练)
    print("开始PSO优化...")
    for it in range(MaxIt):
        for i in range(nPop):
            # 更新速度
            r1 = np.random.rand(nVar)
            r2 = np.random.rand(nVar)
            particles[i].Velocity = (w * particles[i].Velocity +
                                     c1 * r1 * (particles[i].Best['Position'] - particles[i].Position) +
                                     c2 * r2 * (GlobalBest['Position'] - particles[i].Position))
            # 速度限制
            particles[i].Velocity = np.clip(particles[i].Velocity, VelMin, VelMax)
            # 更新位置
            particles[i].Position = particles[i].Position + particles[i].Velocity
            # 位置边界处理(反弹)
            IsOutside = (particles[i].Position < VarMin) | (particles[i].Position > VarMax)
            particles[i].Velocity[IsOutside] = -particles[i].Velocity[IsOutside]
            particles[i].Position = np.clip(particles[i].Position, VarMin, VarMax)
            # 评估适应度
            cost, _ = BpFunction(particles[i].Position, P_tensor, T_tensor, hiddennum,
                                P_test_tensor, T_test_tensor, device)
            particles[i].Cost = cost
            # 更新个体最优
            if cost < particles[i].Best['Cost']:
                particles[i].Best['Position'] = particles[i].Position.copy()
                particles[i].Best['Cost'] = cost
                # 更新全局最优
                if cost < GlobalBest['Cost']:
                    GlobalBest['Position'] = particles[i].Position.copy()
                    GlobalBest['Cost'] = cost
        BestCost[it] = GlobalBest['Cost']
        print(f"Iteration {it+1}: Best Cost = {BestCost[it]:.6f}")
        w = w * wdamp
    
    # ============================================================================ #
    # # 结果展示
    # ============================================================================ #
    plt.figure(figsize=(10, 6))
    plt.semilogy(BestCost, linewidth=2)
    plt.xlabel('迭代次数', fontsize=12)
    plt.ylabel('误差的变化', fontsize=12)
    plt.title('进化过程', fontsize=14)
    plt.grid(True)
    plt.savefig('evolution_pso.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # =>加载数据P、T、P_test、T_test、cur_season=>数据转换P_tensor、T_tensor、P_test_tensor、T_test_tensor、cur_season_tensor
    # =>particle粒子=>particle粒子、cost、GlobalBest、BestCost=>预测的结果值prob
    # ============================================================================ #
    # # 预测
    # ============================================================================ #
    print(f"\n最优初始权值和阈值数量: {len(GlobalBest['Position'])}")
    print(f"最小误差: {GlobalBest['Cost']:.6f}")
    print("\n预测今年总冠军概率...")
    _, bestCur_sim = BpFunction(GlobalBest['Position'], P_tensor, T_tensor, hiddennum,
                                cur_season_tensor, None, device)
    prob = softmax(bestCur_sim.flatten())#将预测结果映射为和为1的概率
    print(f"勇士队获得2022年NBA总冠军概率为: {prob[0]:.4f}")
    print(f"凯尔特人队获得2022年NBA总冠军概率为: {prob[1]:.4f}")
    
    elapsed_time = time.time() - start_time
    print(f"\n总耗时: {elapsed_time:.2f} 秒")

if __name__ == "__main__":
    main()
相关推荐
Cosmoshhhyyy2 小时前
《Effective Java》解读第46条:优先选择Stream中无副作用的函数
java·windows·python
gf13211112 小时前
流光剪辑_调用生成图片模型/apimart调用生成视频模型
python
chenglin0162 小时前
Semantic Kernel 内核详解
后端·python·flask
B站_计算机毕业设计之家2 小时前
计算机毕业设计:Python城市地铁网络可视化分析系统 Flask框架 数据分析 可视化 高德地图 数据挖掘 机器学习 爬虫(建议收藏)✅
网络·python·信息可视化·数据挖掘·flask·课程设计·美食
源码之家2 小时前
计算机毕业设计:Python地铁数据可视化分析系统 Flask框架 数据分析 可视化 高德地图 数据挖掘 机器学习 爬虫(建议收藏)✅
大数据·python·信息可视化·数据挖掘·flask·汽车·课程设计
zhishidi3 小时前
使用python给pdf文档自动添加目录书签
java·python·pdf
chushiyunen12 小时前
python中的@Property和@Setter
java·开发语言·python
禾小西12 小时前
Java中使用正则表达式核心解析
java·python·正则表达式