一维热传导方程的PINN求解——损失函数实时绘制
(PINN solution of the 1-D heat conduction equation, with live plotting of the loss curves)
# -*- coding: utf-8 -*-
######### ---- 精简版 PINN 求解一维热传导方程 ---- #########

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import torch.nn as nn
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt

# ======================
# 基础设置
# ======================
def setup_seed(seed=20):
    """Seed every RNG source (torch CPU, torch CUDA, numpy, random) for reproducibility."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)
    # Force deterministic cuDNN kernels (disables autotuning; trades speed for repeatability).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix all RNG seeds before any sampling so runs are reproducible.
setup_seed(20)
# Prefer GPU when available; every tensor below is created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ======================
# PINN 网络
# ======================
class PINN(nn.Module):
    """Fully-connected network approximating T(t, x).

    Architecture: Linear+Tanh pairs for every hidden layer, followed by a
    final Linear with no activation. Input is the concatenation [t, x]
    (2 features), output is the scalar temperature (1 feature).

    Note: the default is a tuple, not a list — a mutable default argument
    would be shared across calls (callers may still pass a list).
    """

    def __init__(self, layers=(2, 20, 20, 20, 20, 20, 1)):
        super().__init__()
        net = []
        for i in range(len(layers) - 2):
            net.append(nn.Linear(layers[i], layers[i + 1]))
            net.append(nn.Tanh())
        # Output layer: linear, no activation.
        net.append(nn.Linear(layers[-2], layers[-1]))
        self.net = nn.Sequential(*net)

    def forward(self, t, x):
        """Evaluate T(t, x); t and x are column tensors of shape (N, 1)."""
        return self.net(torch.cat([t, x], dim=1))

# ======================
# 封装偏导函数
# ======================
def grad(u, var):
    """Return du/dvar via autograd, keeping the graph alive so that
    higher-order derivatives (e.g. T_xx) can be taken afterwards."""
    ones = torch.ones_like(u)
    (derivative,) = torch.autograd.grad(
        u,
        var,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
    )
    return derivative

# ======================
# 损失函数
# ======================
def physics_loss(model, t, x, alpha=1.0):
    """Mean squared residual of the heat equation T_t - alpha * T_xx = 0
    at the interior collocation points (t, x)."""
    temperature = model(t, x)
    dT_dt = grad(temperature, t)
    dT_dx = grad(temperature, x)
    dT_dxx = grad(dT_dx, x)
    residual = dT_dt - alpha * dT_dxx
    return (residual ** 2).mean()

def boundary_loss(model, t_bc, x_left, x_right):
    """Neumann boundary losses: flux T_x = -1 at the left wall (x=0)
    and an insulated wall T_x = 0 at the right (x=1)."""
    T_at_left = model(t_bc, x_left)
    T_at_right = model(t_bc, x_right)
    # Left wall: enforce T_x + 1 = 0  (prescribed inward flux).
    left_term = torch.mean((grad(T_at_left, x_left) + 1.0) ** 2)
    # Right wall: enforce T_x = 0  (insulated).
    right_term = torch.mean(grad(T_at_right, x_right) ** 2)
    return left_term + right_term

def initial_loss(model, x_ic):
    """Initial-condition loss: enforce T(t=0, x) = 0 at the IC sample points.

    Fix: t0 is built with torch.zeros_like(x_ic) so it inherits x_ic's
    device and dtype. The original passed device=<module-level global>,
    which silently moved t0 to the global device even when x_ic lived
    elsewhere (and coupled this function to module state).
    """
    t0 = torch.zeros_like(x_ic)
    T0 = model(t0, x_ic)
    # Target is identically zero, so the loss is just mean(T0^2).
    return torch.mean(T0 ** 2)

# ======================
# 训练函数(实时绘图)
# ======================
def train(model, optimizer, num_epochs=5000):
    """Train the PINN on a fixed set of sampled points with live loss plotting.

    All collocation/boundary/initial points are sampled once up front (not
    re-sampled per epoch) on the module-level `device`; the sampling order
    matters for reproducibility because the global seed was fixed earlier.
    Returns the four per-epoch loss histories (physics, BC, IC, total).
    """
    # -------- sampling points --------
    N_f, N_bc, N_ic = 8000, 80, 100

    # Interior collocation points in (0,1)x(0,1); requires_grad so the PDE
    # residual can differentiate T w.r.t. t and x.
    t = torch.rand(N_f,1,device=device,requires_grad=True)
    x = torch.rand(N_f,1,device=device,requires_grad=True)

    # Boundary points: random times, fixed x=0 (left) and x=1 (right);
    # x_left/x_right need gradients because the BCs constrain T_x.
    t_bc = torch.rand(N_bc,1,device=device,requires_grad=True)
    x_left = torch.zeros(N_bc,1,device=device,requires_grad=True)
    x_right = torch.ones(N_bc,1,device=device,requires_grad=True)

    # Initial-condition points (no gradient w.r.t. x needed for the IC loss).
    x_ic = torch.rand(N_ic,1,device=device)

    # -------- loss histories --------
    loss_f_list, loss_bc_list, loss_ic_list, loss_total_list = [], [], [], []

    # -------- live plotting (interactive mode) --------
    plt.ion()
    fig, ax = plt.subplots(figsize=(6,4))

    for epoch in tqdm(range(num_epochs), desc="Training"):
        optimizer.zero_grad()
        loss_f = physics_loss(model, t, x)
        loss_bc = boundary_loss(model, t_bc, x_left, x_right)
        loss_ic = initial_loss(model, x_ic)
        # Unweighted sum of the three loss terms.
        loss = loss_f + loss_bc + loss_ic
        loss.backward()
        optimizer.step()

        # Record scalar values (detached from the graph via .item()).
        loss_f_list.append(loss_f.item())
        loss_bc_list.append(loss_bc.item())
        loss_ic_list.append(loss_ic.item())
        loss_total_list.append(loss.item())

        # -------- refresh the live loss curves every 50 epochs --------
        if epoch % 50 == 0:
            ax.clear()
            ax.semilogy(loss_f_list, label="Physics")
            ax.semilogy(loss_bc_list, label="BC")
            ax.semilogy(loss_ic_list, label="IC")
            ax.semilogy(loss_total_list, label="Total")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            ax.legend()
            ax.set_title(f"Epoch {epoch}")
            plt.pause(0.01)  # yields to the GUI event loop so the figure redraws

        # Console progress every 500 epochs.
        if epoch % 500 == 0:
            print(f"[{epoch}] Total={loss.item():.3e} | "
                  f"F={loss_f.item():.3e} | BC={loss_bc.item():.3e} | IC={loss_ic.item():.3e}")

    # Leave interactive mode and block on the final figure.
    plt.ioff()
    plt.show()
    return loss_f_list, loss_bc_list, loss_ic_list, loss_total_list

# ======================
# 主程序
# ======================
# Build the network on the selected device and train with Adam.
model = PINN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Returns the per-epoch histories of the three loss terms and their sum.
loss_f, loss_bc, loss_ic, loss_total = train(model, optimizer, num_epochs=5000)
相关推荐
她说..4 小时前
Java 对象相关高频面试题
java·开发语言·spring·java-ee
花酒锄作田4 小时前
Postgres - Listen/Notify构建轻量级发布订阅系统
python·postgresql
watson_pillow5 小时前
c++ 协程的初步理解
开发语言·c++
庞轩px5 小时前
深入理解 sleep() 与 wait():从基础到监视器队列
java·开发语言·线程··wait·sleep·监视器
Thomas.Sir5 小时前
第二章:LlamaIndex 的基本概念
人工智能·python·ai·llama·llamaindex
故事和你915 小时前
洛谷-算法1-2-排序2
开发语言·数据结构·c++·算法·动态规划·图论
m0_694845575 小时前
Dify部署教程:从AI原型到生产系统的一站式方案
服务器·人工智能·python·数据分析·开源
白毛大侠6 小时前
理解 Go 接口:eface 与 iface 的区别及动态性解析
开发语言·网络·golang
李昊哲小课6 小时前
Python办公自动化教程 - 第7章 综合实战案例 - 企业销售管理系统
开发语言·python·数据分析·excel·数据可视化·openpyxl
Hou'6 小时前
从0到1的C语言传奇之路
c语言·开发语言