Matplotlib绘图详解

Matplotlib绘图详解

  • [1. Matplotlib(plt)绘图详解](#1. Matplotlib(plt)绘图详解)
  • [2. 基础折线图(Line Plot)](#2. 基础折线图(Line Plot))
  • [3. 多条曲线绘制](#3. 多条曲线绘制)
  • [4. 曲线区域填充](#4. 曲线区域填充)
  • [5. 散点图(Scatter Plot)](#5. 散点图(Scatter Plot))
  • [6. 柱状图(Bar Plot)](#6. 柱状图(Bar Plot))
  • [7. 直方图(Histogram)](#7. 直方图(Histogram))
  • [8. 箱线图(Box Plot)](#8. 箱线图(Box Plot))
  • [9. 饼图(Pie Chart)](#9. 饼图(Pie Chart))
  • [10. 热力图(Heatmap)](#10. 热力图(Heatmap))
  • [11. 混淆矩阵(Confusion Matrix)](#11. 混淆矩阵(Confusion Matrix))
  • [12. 三维图](#12. 三维图)
  • [13. 多子图绘制](#13. 多子图绘制)
  • [14. 图像保存](#14. 图像保存)
  • [15. 科研绘图常用全局配置](#15. 科研绘图常用全局配置)

1. Matplotlib(plt)绘图详解

Matplotlib简介

Matplotlib 是 Python 中最流行的数据可视化库之一,也是科研领域使用最广泛的绘图库之一。其中 matplotlib.pyplot 模块提供了类似 MATLAB 的绘图接口,因此通常采用如下方式导入:

python 复制代码
import matplotlib.pyplot as plt
import numpy as np

Matplotlib 具有以下特点:

  • 功能丰富,支持二维和三维绘图
  • 与 NumPy、Pandas、PyTorch 等科学计算库兼容性良好
  • 可高度自定义图像样式
  • 支持导出高分辨率图片和矢量图
  • 广泛应用于机器学习、深度学习、信号处理等领域

Matplotlib 常用于:

  • 展示训练 Loss 曲线
  • 展示 Accuracy 曲线
  • 绘制混淆矩阵
  • 绘制 t-SNE 特征分布图
  • 绘制模型性能对比图

2. 基础折线图(Line Plot)

折线图是最常见的图表类型,用于展示数据随时间或序列变化的趋势。

例如,在深度学习训练过程中,经常使用折线图观察 Loss 和 Accuracy 的变化情况。

基本绘制方法

python 复制代码
import matplotlib.pyplot as plt

x = [1,2,3,4,5]
y = [2,4,3,6,8]

plt.plot(x, y)

plt.show()

其中:

  • x 为横坐标数据
  • y 为纵坐标数据
  • plot() 用于绘制折线
  • show() 用于显示图像

添加标题和坐标轴标签

python 复制代码
plt.plot(x, y)

plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.show()

常用设置:

函数 作用
title() 设置标题
xlabel() 设置横坐标名称
ylabel() 设置纵坐标名称

设置线条样式

python 复制代码
plt.plot(
    x,
    y,
    color='red',
    linestyle='--',
    linewidth=2,
    marker='o'
)

plt.show()

参数说明:

参数 含义
color 颜色
linestyle 线型
linewidth 线宽
marker 数据点样式

常用颜色:

python 复制代码
'red'
'blue'
'green'
'black'
'yellow'

常用线型:

python 复制代码
'-'     # 实线
'--'    # 虚线
'-.'    # 点划线
':'     # 点线

常用标记:

python 复制代码
'o'     # 圆点
's'     # 方块
'^'     # 三角形
'*'     # 星号

3. 多条曲线绘制

在实验对比中,经常需要将多个模型的结果绘制在同一张图中。

python 复制代码
epoch = [1,2,3,4,5]

model1 = [60,70,78,83,85]
model2 = [58,72,80,86,88]
model3  = [65,78,85,90,93]

plt.plot(epoch, model1, label='1')
plt.plot(epoch, model2, label='2')
plt.plot(epoch, model3, label='3')

plt.legend()

plt.show()

图例设置

python 复制代码
plt.legend(loc='upper left')

常见位置:

参数 位置
upper left 左上
upper right 右上
lower left 左下
lower right 右下

科研论文中一般放在:

python 复制代码
upper right

或者:

python 复制代码
upper left

避免遮挡曲线。


4. 曲线区域填充

为了增强可视化效果,可以对曲线下方区域进行填充。

python 复制代码
plt.plot(x, y)

plt.fill_between(
    x,
    y,
    alpha=0.3
)

plt.show()

其中:

python 复制代码
alpha

表示透明度。

取值范围:

python 复制代码
0 ~ 1

应用场景:

  • Loss 曲线
  • Accuracy 曲线
  • 置信区间展示
  • 误差带展示

5. 散点图(Scatter Plot)

散点图主要用于展示样本分布情况,中最常见的应用是:

  • t-SNE 可视化
  • PCA 可视化
  • 特征聚类展示

基础示例

python 复制代码
x = np.random.rand(100)
y = np.random.rand(100)

plt.scatter(x, y)

plt.show()

分类散点图

python 复制代码
plt.scatter(x1, y1, label='1')

plt.scatter(x2, y2, label='2')

plt.scatter(x3, y3, label='3')

plt.legend()

plt.show()

用于观察:

  • 类内紧凑程度
  • 类间分离程度

6. 柱状图(Bar Plot)

柱状图主要用于比较不同类别数据之间的大小关系。

例如:

  • 不同模型准确率比较
  • 不同算法运行时间比较
python 复制代码
models = ['1','2','3','4']
acc = [85,88,90,95]

plt.bar(models, acc)

plt.show()

显示数值标签

python 复制代码
bars = plt.bar(models, acc)

for bar in bars:

    plt.text(
        bar.get_x()+0.2,
        bar.get_height(),
        str(bar.get_height())
    )

plt.show()

7. 直方图(Histogram)

直方图用于展示数据分布情况。

python 复制代码
data = np.random.randn(1000)

plt.hist(
    data,
    bins=30
)

plt.show()

参数:

python 复制代码
bins=30

表示划分为30个区间。


概率密度直方图

python 复制代码
plt.hist(
    data,
    bins=30,
    density=True
)

常用于:

  • 特征值统计分析
  • 数据集分布分析

8. 箱线图(Box Plot)

箱线图用于展示数据统计特征。

python 复制代码
data = [
    np.random.randn(100),
    np.random.randn(100)+1
]

plt.boxplot(data)

plt.show()

箱线图可以展示:

  • 最大值
  • 最小值
  • 中位数
  • 上四分位数
  • 下四分位数
  • 异常值

适用于:

  • 多次实验稳定性分析
  • 模型鲁棒性分析

9. 饼图(Pie Chart)

用于展示各部分所占比例。

python 复制代码
sizes = [40,30,20,10]

labels = [
    '1',
    '2',
    '3',
    '4'
]

plt.pie(
    sizes,
    labels=labels,
    autopct='%1.1f%%'
)

plt.show()

应用场景:

  • 数据集类别比例

10. 热力图(Heatmap)

热力图利用颜色表示数值大小。

python 复制代码
data = np.random.rand(10,10)

plt.imshow(
    data,
    cmap='jet'
)

plt.colorbar()

plt.show()

常用颜色映射

python 复制代码
cmap='jet'
cmap='hot'
cmap='viridis'
cmap='coolwarm'

应用:

  • 注意力权重可视化
  • 特征响应图
  • 卷积特征图

11. 混淆矩阵(Confusion Matrix)

分类任务中最重要的评价图之一。

python 复制代码
cm = np.array([
    [95,2,1],
    [3,92,5],
    [1,4,95]
])

plt.imshow(
    cm,
    cmap='Blues'
)

plt.colorbar()

plt.show()

通过混淆矩阵可以分析:

  • 哪些类别容易混淆
  • 分类器识别能力
  • 类别间可分性

12. 三维图

python 复制代码
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()

ax = fig.add_subplot(
    111,
    projection='3d'
)

ax.scatter(x,y,z)

plt.show()

应用:

  • 三维特征空间展示
  • 三维 t-SNE
  • 三维轨迹分析

13. 多子图绘制

论文中经常需要在同一张图中展示多个结果。

python 复制代码
fig, ax = plt.subplots(
    2,
    2,
    figsize=(10,8)
)

示例:

python 复制代码
ax[0,0].plot(x,y)

ax[0,1].scatter(x,y)

ax[1,0].hist(data)

ax[1,1].bar(models,acc)

适用于:

  • 多模型结果对比

14. 图像保存

论文投稿时必须掌握。

python 复制代码
plt.savefig(
    "result.png",
    dpi=300,
    bbox_inches='tight'
)

参数说明:

参数 作用
dpi 分辨率
bbox_inches 去除空白边缘

推荐保存格式

位图:

python 复制代码
plt.savefig("result.png")

矢量图:

python 复制代码
plt.savefig("result.svg")
python 复制代码
plt.savefig("result.pdf")

科研推荐:

python 复制代码
dpi = 600

15. 科研绘图常用全局配置

统一图像风格能够显著提升论文质量。

python 复制代码
plt.rcParams['font.size'] = 14

plt.rcParams['figure.figsize'] = (8,6)

plt.rcParams['axes.unicode_minus'] = False

中文支持:

python 复制代码
plt.rcParams['font.sans-serif'] = ['SimHei']