基于卷积神经网络的高光谱分类 CNN
- 混合光谱HybridSN
-
- [传统的2-D CNN](#传统的2-D CNN)
- [混合光谱3-D CNN](#混合光谱3-D CNN)
- 操作步骤
混合光谱HybridSN
传统的2-D CNN
传统的2-D CNN方法在处理HSI时往往只考虑了光谱信息,而忽略了空间信息的重要性。
混合光谱3-D CNN
HybridSN通过引入3-D CNN的思想,将
光谱信息
和空间信息
结合在一起进行特征学习和分类。具体来说,HybridSN在网络结构中设计了专门处理光谱信息的卷积层和处理空间信息的卷积层,同时考虑了各个波段之间的相关性和空间上的局部特征
。这种混合光谱的设计能够更全面地捕捉HSI图像中的特征,提高分类性能并减少信息损失。
操作步骤
环境:Jupyter Notebook
前言(准备)
获取数据以及引入基本的库函数
python
# 下载Indian Pines数据集的纠正版本和地面真实值数据集
! wget http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat
! wget http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat
# 安装Python库spectral,用于处理和分析遥感数据
! pip install spectral
导入相关的包
python
import numpy as np # 导入NumPy库,用于处理数组和矩阵运算
import matplotlib.pyplot as plt # 导入Matplotlib库,用于绘制图表和可视化数据
import scipy.io as sio # 导入SciPy库的io模块,用于读取和写入MATLAB文件格式
from sklearn.decomposition import PCA # 导入PCA算法,用于数据降维
from sklearn.model_selection import train_test_split # 导入train_test_split函数,用于划分训练集和测试集
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score # 导入评估指标函数
import spectral # 导入Spectral Python库,用于处理和分析遥感数据
import torch # 导入PyTorch库,用于构建神经网络模型
import torchvision # 导入PyTorch的视觉库,用于处理图像数据
import torch.nn as nn # 导入PyTorch的神经网络模块
import torch.nn.functional as F # 导入PyTorch的神经网络函数
import torch.optim as optim # 导入PyTorch的优化器模块
创建模型
模型网络结构
三维处理
- conv1:(1, 30, 25, 25), 8个 7x3x3 的卷积核 ==>(8, 24, 23, 23)
- conv2:(8, 24, 23, 23), 16个 5x3x3 的卷积核 ==>(16, 20, 21, 21)
- conv3:(16, 20, 21, 21),32个 3x3x3 的卷积核 ==>(32, 18, 19, 19)
具体来说,对于输入数据的维度为(1, 30, 25, 25),其中1表示通道数,30表示光谱维度,25表示空间维度(高度和宽度)。
在你提供的三维卷积部分中,"conv1"、"conv2"和"conv3"分别表示三个卷积层,每个卷积层通过不同大小的卷积核对输入数据进行卷积操作,得到相应的输出特征图。这些卷积层通常会被跟随着激活函数、池化操作等其他层,构成一个完整的深度学习模型。
二维处理
接下来要进行二维卷积,因此把前面的 32*18 reshape 一下,得到 (576, 19, 19)
二维卷积:(576, 19, 19) 64个 3x3 的卷积核,得到 (64, 17, 17)
1.
Reshape操作:
将(32, 18)的特征矩阵reshape为(576, 19, 19)的三维特征矩阵,其中576表示特征的数量,19表示空间维度。2.
二维卷积操作:
使用64个3x3的二维卷积核对(576, 19, 19)的三维特征矩阵进行卷积操作。每个3x3的卷积核在三维特征矩阵上进行滑动操作,计算每个位置的卷积结果,最终得到64个输出通道的特征图。3.
输出特征图:
经过64个3x3的卷积核的卷积操作后,我们得到了一个维度为(64, 17, 17)的输出特征图,其中64表示卷积核的数量,17表示空间维度。这个输出特征图将作为下一层神经网络的输入,继续进行后续的处理和学习。
一维处理
接下来是一个 flatten 操作,变为 18496 维的向量
接下来依次为256,128节点的全连接层,都使用比例为0.4的 Dropout,最后输出为 16 个节点,是最终的分类类别数
1.
Flatten操作:
将(64, 17, 17)的特征图展平为一个18496维的向量。2.
全连接层1(256节点):
将18496维的向量输入到一个拥有256个节点的全连接层中,进行权重计算和激活操作。3.
Dropout操作(比例为0.4):
在全连接层1的输出上应用Dropout操作,随机丢弃40%的神经元,以防止过拟合。4.
全连接层2(128节点):
将经过Dropout操作后的输出输入到一个拥有128个节点的全连接层中,进行权重计算和激活操作。5.
Dropout操作(比例为0.4):
在全连接层2的输出上再次应用Dropout操作,同样丢弃40%的神经元。6.
输出层(16个节点):
最后将经过Dropout操作后的输出输入到一个拥有16个节点的全连接层中,这个全连接层的输出就是最终的分类类别数。
代码
python
class_num = 16
class HybridSN(nn.Module):
def __init__(self):
super(HybridSN, self).__init__()
# 3D卷积层
self.conv_3d = nn.Sequential(
nn.Conv3d(1, 8, (7, 3, 3)), # 输入通道数为1,输出通道数为8,卷积核大小为(7, 3, 3)
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活函数
nn.Conv3d(8, 16, (5, 3, 3)), # 输入通道数为8,输出通道数为16,卷积核大小为(5, 3, 3)
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活函数
nn.Conv3d(16, 32, (3, 3, 3)), # 输入通道数为16,输出通道数为32,卷积核大小为(3, 3, 3)
nn.LeakyReLU(0.2, inplace=True) # LeakyReLU激活函数
)
# 2D卷积层
self.conv_2d = nn.Sequential(
nn.Conv2d(576, 64, (3, 3)), # 输入通道数为576,输出通道数为64,卷积核大小为(3, 3)
nn.LeakyReLU(0.2, inplace=True) # LeakyReLU激活函数
)
# 全连接层
self.linear = nn.Sequential(
nn.Linear(18496, 256), # 输入特征维度为18496,输出特征维度为256
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活函数
nn.Dropout(0.4), # Dropout操作,丢弃比例为0.4
nn.Linear(256, 128), # 输入特征维度为256,输出特征维度为128
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活函数
nn.Dropout(0.4), # Dropout操作,丢弃比例为0.4
nn.Linear(128, class_num), # 最终输出层,输出类别数为class_num
nn.LogSoftmax(dim=1) # LogSoftmax函数,用于多分类问题的输出
)
def forward(self, x):
x = self.conv_3d(x) # 3D卷积操作
x = x.view(-1, x.shape[1] * x.shape[2], x.shape[3], x.shape[4]) # reshape操作
x = self.conv_2d(x) # 2D卷积操作
x = x.view(x.size(0), -1) # flatten操作
x = self.linear(x) # 全连接层操作
return x
以上代码是一个名为HybridSN的神经网络模型类,包含了3D卷积层、2D卷积层和全连接层。在forward方法中,定义了模型的前向传播过程,包括卷积操作、reshape操作、激活函数、Dropout操作和LogSoftmax函数。
测试
python
#测试网络结构是否通
def test_net():
# 随机输入
x = torch.randn(1, 1, 30, 25, 25)
net = HybridSN()
y = net(x)
print(y.shape)
test_net()