Python 机器学习求解 PDE 学习项目 基础知识(3)matplotlib 画函数热图

绘制模型输出的热图

前言

在科学计算和工程应用中,偏微分方程(PDE)的数值求解是一项常见且重要的任务。PDE的求解通常需要数值方法的辅助,例如有限元法、有限差分法和有限体积法等。这些方法能够提供对物理系统的精确模拟。然而,仅仅得到数值解是不够的,数据的可视化是理解和分析这些解的关键步骤。

matplotlib是Python中最广泛使用的绘图库之一。它为用户提供了强大的功能来可视化数据。在PDE的后处理阶段,利用matplotlib绘制热图、曲线图等能够帮助我们直观地观察到解的分布和变化趋势,从而进一步分析物理现象。

代码说明

本节代码展示了如何使用 matplotlib 和 torch(本案例使用 2.1 版本)来绘制一个简单模型的输出热图。由于深度学习求解结果通常是 torch 的 tensor 格式,所以我们这里也使用 tensor 数据格式。

导入必要的库

python 复制代码
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
  • numpy: 提供支持大型多维数组与矩阵运算,以及大量的数学函数库。
  • torch: 一个用于深度学习的开源机器学习框架。
  • matplotlib.pyplot: 一个用于生成图形的命令式接口。
  • matplotlib.cm: 包含用于着色的色彩映射函数。

函数定义:draw_graph

imshow 默认情况下会将 y 轴反向显示。这是因为 imshow 将数组的第一维作为y轴方向,而第二维作为 x 轴方向,从左上角开始绘制。为了纠正 y 轴方向,我们可以通过调整 extent 参数来确保 y 轴从底部向上增长。

python 复制代码
def draw_graph(mod, m):
    """
    绘制模型输出与真实值函数的组合图像,在输入值的网格上显示为热图。

    参数:
    - mod (callable): 接受张量输入并输出标量的模型。
    - m (int): 模型期望的输入张量大小。
    """
    # 创建范围为 [-1, 1] 的点网格
    N = 10
    points = np.linspace(-1, 1, N)
    xs, ys = np.meshgrid(points, points)

    # 将网格转换为 torch 张量
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)

    # 获取网格大小
    xl, yl = xs.size()

    # 初始化矩阵以保存计算值
    z = np.zeros((xl, yl))

    # 计算网格中每个点的模型输出和真实值
    for i in range(xl):
        for j in range(yl):      
            # 创建大小为 m 的输入张量
            re = np.zeros(m)
            re[0] = xs[i, j]
            re[1] = ys[i, j]
            re = torch.tensor(re)
            
            # 计算输出值
            z[i, j] = mod(re.float()).item()

    # 将结果绘制为热图,使用 extent 调整坐标系方向
    plt.imshow(z, cmap=cm.hot, extent=(-1, 1, -1, 1), origin='lower')
    plt.colorbar()

    # 自动生成轴的标签
    ax = plt.gca()
    x_ticks = np.linspace(-1, 1, N)
    y_ticks = np.linspace(-1, 1, N)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([f'{x:.2f}' for x in x_ticks])
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{y:.2f}' for y in y_ticks])
    
    # 显示图像
    plt.show()

示例测试用例

python 复制代码
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        # 一个简单的模型,计算输入元素的和
        return x.sum()

实例化模型

python 复制代码
model = SimpleModel()

使用模型和输入大小 m = 2 测试 draw_graph 函数

python 复制代码
draw_graph(model, m=2)

示例测试用例

python 复制代码
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        # 一个简单的模型,计算输入元素的和
        return x.sum()

实例化模型

python 复制代码
model = SimpleModel()

使用模型和输入大小 m = 2 测试 draw_graph 函数

python 复制代码
draw_graph(model, m=2)

绘制结果:

例子2

我们可以改变热图的颜色映射 cmap ,网格密度 N 和 SimpleModel 中的函数函表达式:

python 复制代码
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm


def draw_graph(mod, m):
    """
    绘制模型输出与真实值函数的组合图像,在输入值的网格上显示为热图。

    参数:
    - mod (callable): 接受张量输入并输出标量的模型。
    - m (int): 模型期望的输入张量大小。
    """
    # 创建范围为 [-1, 1] 的点网格
    N = 20
    points = np.linspace(-1, 1, N)
    xs, ys = np.meshgrid(points, points)

    # 将网格转换为 torch 张量
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)

    # 获取网格大小
    xl, yl = xs.size()

    # 初始化矩阵以保存计算值
    z = np.zeros((xl, yl))

    # 计算网格中每个点的模型输出和真实值
    for i in range(xl):
        for j in range(yl):      
            # 创建大小为 m 的输入张量
            re = np.zeros(m)
            re[0] = xs[i, j]
            re[1] = ys[i, j]
            re = torch.tensor(re)
            
            # 计算输出值
            z[i, j] = mod(re.float()).item()

    # 将结果绘制为热图,使用 coolwarm cmap
    plt.imshow(z, cmap=cm.coolwarm, extent=(-1, 1, -1, 1), origin='lower')
    plt.colorbar()

    # 自动生成轴的标签
    ax = plt.gca()
    x_ticks = np.linspace(-1, 1, 11)
    y_ticks = np.linspace(-1, 1, 11)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([f'{x:.1f}' for x in x_ticks])
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{y:.1f}' for y in y_ticks])
    
    # 显示图像
    plt.show()

# 示例测试用例
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        # 修改模型表达式,例如使用 sin 函数和 cos 函数
        return torch.sin(x[0]) + torch.cos(x[1])

# 实例化模型
model = SimpleModel()

# 使用模型和输入大小 m = 2 测试 draw_graph 函数
draw_graph(model, m=2)

结果如下:

非正方形区域如何绘图?

例如我们想要要绘制一个三角形区域而不是矩形,我们需要对点的选择和处理进行一些更改,以确保只计算和绘制位于三角形区域内的点。具体来说,可以在遍历网格点时只保留位于特定三角形内的点来计算模型的输出。

假设我们想要画出
x + y < = 0 x+y<=0 x+y<=0 以下的三角形区域,只需要在遍历时加入判断语句,区域之外的点值设置为 NaN 即可:

python 复制代码
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def draw_graph(mod, m):
    """
    绘制模型输出与真实值函数的组合图像,在输入值的三角形网格上显示为热图。

    参数:
    - mod (callable): 接受张量输入并输出标量的模型。
    - m (int): 模型期望的输入张量大小。
    """
    # 创建范围为 [-1, 1] 的点网格
    N = 100  # 增加点的密度以获得更高分辨率的图像
    points = np.linspace(-1, 1, N)
    xs, ys = np.meshgrid(points, points)

    # 将网格转换为 torch 张量
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)

    # 获取网格大小
    xl, yl = xs.size()

    # 初始化矩阵以保存计算值
    z = np.zeros((xl, yl))

    # 计算三角形内每个点的模型输出
    for i in range(xl):
        for j in range(yl):      
            # 检查点是否在三角形内
            if xs[i, j] >= -1 and ys[i, j] >= -1 and ys[i, j] + xs[i, j] <=  0:
                # 创建大小为 m 的输入张量
                re = np.zeros(m)
                re[0] = xs[i, j]
                re[1] = ys[i, j]
                # 新增第三个维度使用固定值,例如 0.1
                re[2] = 0.1  
                re = torch.tensor(re)
                
                # 计算输出值
                z[i, j] = mod(re.float()).item()
            else:
                z[i, j] = np.nan  # 对于三角形之外的点,设为 NaN,不显示

    # 将结果绘制为热图,使用 coolwarm cmap
    plt.imshow(z, cmap=cm.coolwarm, extent=(-1, 1, -1, 1), origin='lower')
    plt.colorbar()

    # 自动生成轴的标签
    ax = plt.gca()
    x_ticks = np.linspace(-1, 1, 11)
    y_ticks = np.linspace(-1, 1, 11)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([f'{x:.1f}' for x in x_ticks])
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f'{y:.1f}' for y in y_ticks])
    
    # 显示图像
    plt.show()

# 示例测试用例
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        # 修改模型表达式,使用 sin 和 cos 函数并结合第三个维度
        return torch.sin(x[0]) + torch.cos(x[1]) + 2*x[2]**2

# 实例化模型
model = SimpleModel()

# 使用模型和输入大小 m = 3 测试 draw_graph 函数
draw_graph(model, m=3)

绘制效果图:

对于其他多边形区域,也可以类似切割,并加入判断语句,实现在指定区域画图的效果。


本专栏致力于普及各种偏微分方程的不同数值求解方法,所有文章包含全部可运行代码。欢迎大家支持、关注!

作者 :计算小屋
个人主页计算小屋的主页

相关推荐
sg_knight1 分钟前
Python 面向对象基础复习
开发语言·python·ai编程·面向对象·模型
刘洋浪子7 分钟前
Git命令学习
git·学习·elasticsearch
dhdjjsjs27 分钟前
Day35 PythonStudy
python
如竟没有火炬1 小时前
四数相加贰——哈希表
数据结构·python·算法·leetcode·散列表
JoannaJuanCV1 小时前
自动驾驶—CARLA仿真(5)Actors与Blueprints
人工智能·机器学习·自动驾驶
大白的编程日记.1 小时前
【计算网络学习笔记】Socket编程UDP实现简单聊天室
网络·笔记·学习
背心2块钱包邮1 小时前
第9节——部分分式积分(Partial Fraction Decomposition)
人工智能·python·算法·机器学习·matplotlib
木盏1 小时前
三维高斯的分裂
开发语言·python
a程序小傲1 小时前
京东Java面试被问:ZGC的染色指针如何实现?内存屏障如何处理?
java·后端·python·面试
serve the people1 小时前
如何区分什么场景下用机器学习,什么场景下用深度学习
人工智能·深度学习·机器学习