Matplotlib绘图详解
- [1. Matplotlib(plt)绘图详解](#1. Matplotlib(plt)绘图详解)
- [2. 基础折线图(Line Plot)](#2. 基础折线图(Line Plot))
- [3. 多条曲线绘制](#3. 多条曲线绘制)
- [4. 曲线区域填充](#4. 曲线区域填充)
- [5. 散点图(Scatter Plot)](#5. 散点图(Scatter Plot))
- [6. 柱状图(Bar Plot)](#6. 柱状图(Bar Plot))
- [7. 直方图(Histogram)](#7. 直方图(Histogram))
- [8. 箱线图(Box Plot)](#8. 箱线图(Box Plot))
- [9. 饼图(Pie Chart)](#9. 饼图(Pie Chart))
- [10. 热力图(Heatmap)](#10. 热力图(Heatmap))
- [11. 混淆矩阵(Confusion Matrix)](#11. 混淆矩阵(Confusion Matrix))
- [12. 三维图](#12. 三维图)
- [13. 多子图绘制](#13. 多子图绘制)
- [14. 图像保存](#14. 图像保存)
- [15. 科研绘图常用全局配置](#15. 科研绘图常用全局配置)
1. Matplotlib(plt)绘图详解
Matplotlib简介
Matplotlib 是 Python 中最流行的数据可视化库之一,也是科研领域使用最广泛的绘图库之一。其中 matplotlib.pyplot 模块提供了类似 MATLAB 的绘图接口,因此通常采用如下方式导入:
python
import matplotlib.pyplot as plt
import numpy as np
Matplotlib 具有以下特点:
- 功能丰富,支持二维和三维绘图
- 与 NumPy、Pandas、PyTorch 等科学计算库兼容性良好
- 可高度自定义图像样式
- 支持导出高分辨率图片和矢量图
- 广泛应用于机器学习、深度学习、信号处理等领域
Matplotlib 常用于:
- 展示训练 Loss 曲线
- 展示 Accuracy 曲线
- 绘制混淆矩阵
- 绘制 t-SNE 特征分布图
- 绘制模型性能对比图
2. 基础折线图(Line Plot)
折线图是最常见的图表类型,用于展示数据随时间或序列变化的趋势。
例如,在深度学习训练过程中,经常使用折线图观察 Loss 和 Accuracy 的变化情况。
基本绘制方法
python
import matplotlib.pyplot as plt
x = [1,2,3,4,5]
y = [2,4,3,6,8]
plt.plot(x, y)
plt.show()
其中:
- x 为横坐标数据
- y 为纵坐标数据
- plot() 用于绘制折线
- show() 用于显示图像
添加标题和坐标轴标签
python
plt.plot(x, y)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
常用设置:
| 函数 | 作用 |
|---|---|
| title() | 设置标题 |
| xlabel() | 设置横坐标名称 |
| ylabel() | 设置纵坐标名称 |
设置线条样式
python
plt.plot(
x,
y,
color='red',
linestyle='--',
linewidth=2,
marker='o'
)
plt.show()
参数说明:
| 参数 | 含义 |
|---|---|
| color | 颜色 |
| linestyle | 线型 |
| linewidth | 线宽 |
| marker | 数据点样式 |
常用颜色:
python
'red'
'blue'
'green'
'black'
'yellow'
常用线型:
python
'-' # 实线
'--' # 虚线
'-.' # 点划线
':' # 点线
常用标记:
python
'o' # 圆点
's' # 方块
'^' # 三角形
'*' # 星号
3. 多条曲线绘制
在实验对比中,经常需要将多个模型的结果绘制在同一张图中。
python
epoch = [1,2,3,4,5]
model1 = [60,70,78,83,85]
model2 = [58,72,80,86,88]
model3 = [65,78,85,90,93]
plt.plot(epoch, model1, label='1')
plt.plot(epoch, model2, label='2')
plt.plot(epoch, model3, label='3')
plt.legend()
plt.show()
图例设置
python
plt.legend(loc='upper left')
常见位置:
| 参数 | 位置 |
|---|---|
| upper left | 左上 |
| upper right | 右上 |
| lower left | 左下 |
| lower right | 右下 |
科研论文中一般放在:
python
upper right
或者:
python
upper left
避免遮挡曲线。
4. 曲线区域填充
为了增强可视化效果,可以对曲线下方区域进行填充。
python
plt.plot(x, y)
plt.fill_between(
x,
y,
alpha=0.3
)
plt.show()
其中:
python
alpha
表示透明度。
取值范围:
python
0 ~ 1
应用场景:
- Loss 曲线
- Accuracy 曲线
- 置信区间展示
- 误差带展示
5. 散点图(Scatter Plot)
散点图主要用于展示样本分布情况,中最常见的应用是:
- t-SNE 可视化
- PCA 可视化
- 特征聚类展示
基础示例
python
x = np.random.rand(100)
y = np.random.rand(100)
plt.scatter(x, y)
plt.show()
分类散点图
python
plt.scatter(x1, y1, label='1')
plt.scatter(x2, y2, label='2')
plt.scatter(x3, y3, label='3')
plt.legend()
plt.show()
用于观察:
- 类内紧凑程度
- 类间分离程度
6. 柱状图(Bar Plot)
柱状图主要用于比较不同类别数据之间的大小关系。
例如:
- 不同模型准确率比较
- 不同算法运行时间比较
python
models = ['1','2','3','4']
acc = [85,88,90,95]
plt.bar(models, acc)
plt.show()
显示数值标签
python
bars = plt.bar(models, acc)
for bar in bars:
plt.text(
bar.get_x()+0.2,
bar.get_height(),
str(bar.get_height())
)
plt.show()
7. 直方图(Histogram)
直方图用于展示数据分布情况。
python
data = np.random.randn(1000)
plt.hist(
data,
bins=30
)
plt.show()
参数:
python
bins=30
表示划分为30个区间。
概率密度直方图
python
plt.hist(
data,
bins=30,
density=True
)
常用于:
- 特征值统计分析
- 数据集分布分析
8. 箱线图(Box Plot)
箱线图用于展示数据统计特征。
python
data = [
np.random.randn(100),
np.random.randn(100)+1
]
plt.boxplot(data)
plt.show()
箱线图可以展示:
- 最大值
- 最小值
- 中位数
- 上四分位数
- 下四分位数
- 异常值
适用于:
- 多次实验稳定性分析
- 模型鲁棒性分析
9. 饼图(Pie Chart)
用于展示各部分所占比例。
python
sizes = [40,30,20,10]
labels = [
'1',
'2',
'3',
'4'
]
plt.pie(
sizes,
labels=labels,
autopct='%1.1f%%'
)
plt.show()
应用场景:
- 数据集类别比例
10. 热力图(Heatmap)
热力图利用颜色表示数值大小。
python
data = np.random.rand(10,10)
plt.imshow(
data,
cmap='jet'
)
plt.colorbar()
plt.show()
常用颜色映射
python
cmap='jet'
cmap='hot'
cmap='viridis'
cmap='coolwarm'
应用:
- 注意力权重可视化
- 特征响应图
- 卷积特征图
11. 混淆矩阵(Confusion Matrix)
分类任务中最重要的评价图之一。
python
cm = np.array([
[95,2,1],
[3,92,5],
[1,4,95]
])
plt.imshow(
cm,
cmap='Blues'
)
plt.colorbar()
plt.show()
通过混淆矩阵可以分析:
- 哪些类别容易混淆
- 分类器识别能力
- 类别间可分性
12. 三维图
python
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(
111,
projection='3d'
)
ax.scatter(x,y,z)
plt.show()
应用:
- 三维特征空间展示
- 三维 t-SNE
- 三维轨迹分析
13. 多子图绘制
论文中经常需要在同一张图中展示多个结果。
python
fig, ax = plt.subplots(
2,
2,
figsize=(10,8)
)
示例:
python
ax[0,0].plot(x,y)
ax[0,1].scatter(x,y)
ax[1,0].hist(data)
ax[1,1].bar(models,acc)
适用于:
- 多模型结果对比
14. 图像保存
论文投稿时必须掌握。
python
plt.savefig(
"result.png",
dpi=300,
bbox_inches='tight'
)
参数说明:
| 参数 | 作用 |
|---|---|
| dpi | 分辨率 |
| bbox_inches | 去除空白边缘 |
推荐保存格式
位图:
python
plt.savefig("result.png")
矢量图:
python
plt.savefig("result.svg")
python
plt.savefig("result.pdf")
科研推荐:
python
dpi = 600
15. 科研绘图常用全局配置
统一图像风格能够显著提升论文质量。
python
plt.rcParams['font.size'] = 14
plt.rcParams['figure.figsize'] = (8,6)
plt.rcParams['axes.unicode_minus'] = False
中文支持:
python
plt.rcParams['font.sans-serif'] = ['SimHei']