【深度学习基础模型】稀疏自编码器 (Sparse Autoencoders, SAE)详细理解并附实现代码。

【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model

【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model


文章目录

  • [【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model](#【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model)
  • [1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用](#1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用)
    • [1.1 SAE 原理](#1.1 SAE 原理)
    • [1.2 SAE 的主要特征:](#1.2 SAE 的主要特征:)
    • [1.3 SAE 的应用领域:](#1.3 SAE 的应用领域:)
  • [2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用](#2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用)
    • [2.1 SAE 模型的实现](#2.1 SAE 模型的实现)
    • [2.2 代码解释](#2.2 代码解释)
  • [3. 总结](#3. 总结)

参考地址:https://www.asimovinstitute.org/neural-network-zoo/

论文地址:https://www.cs.toronto.edu/\~ranzato/publications/ranzato-nips06.pdf

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用

1.1 SAE 原理

稀疏自编码器(Sparse Autoencoder, SAE)是一种特殊类型的自编码器,其设计目的是在编码过程中引入稀疏性,以鼓励网络学习更多的特征。与标准自编码器不同,SAE 的目标是将输入信息编码到一个比输入更高维的空间中,帮助提取多种小特征。

1.2 SAE 的主要特征:

  • 稀疏性:通过引入稀疏性约束,限制在某一时刻只有少数神经元被激活。这可以通过添加一个稀疏性损失项来实现,该项鼓励大多数神经元在输出中保持静默。
  • 网络结构:SAE 通常包含一个较小的中间层,但该中间层的激活仅通过部分神经元实现,从而生成丰富的特征表示。
  • 特征提取:SAE 适用于特征提取,尤其是在需要捕捉数据中细微差别的任务中,如图像分类、异常检测等。

1.3 SAE 的应用领域:

  • 图像处理:可以用于从遥感图像中提取细节特征。
  • 异常检测:在数据中识别异常点。
  • 生物信息学:提取基因表达数据中的重要特征。

在遥感领域,SAE 可以用于混合像元的分解,帮助识别和分离不同地物的光谱特征。

2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用

以下是一个简单的稀疏自编码器实现示例,展示如何在遥感图像的混合像元分解中应用 SAE。

2.1 SAE 模型的实现

csharp 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# 定义稀疏自编码器模型
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重

    def forward(self, x):
        encoded = torch.relu(self.encoder(x))  # 编码过程
        decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
        return decoded, encoded

    def sparsity_loss(self, p_h):
        # 计算稀疏性损失
        return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                      (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))

# 生成模拟遥感图像数据 (64 维特征)
X = np.random.rand(1000, 64)  # 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)

# 创建数据加载器
dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型、优化器
input_size = 64
hidden_size = 32  # 隐藏层大小
sparsity_param = 0.05  # 稀疏性参数
beta = 1  # 稀疏性损失权重
sae = SparseAutoencoder(input_size=input_size, hidden_size=hidden_size, sparsity_param=sparsity_param, beta=beta)
optimizer = optim.Adam(sae.parameters(), lr=0.001)

# 训练 SAE 模型
num_epochs = 50
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        # 计算稀疏性损失
        p_h = torch.mean(encoded_data, dim=0)
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 使用训练好的模型进行特征提取
with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征

# 可视化原始数据和编码后的特征
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')

plt.subplot(1, 2, 2)
plt.title('Encoded Features')
plt.imshow(encoded_data[:10], aspect='auto', cmap='hot')

plt.show()

2.2 代码解释

1. 模型定义:

csharp 复制代码
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重
  • SparseAutoencoder 类定义了编码器和解码器的结构,同时定义了稀疏性参数和稀疏性损失权重。

2. 前向传播:

csharp 复制代码
def forward(self, x):
    encoded = torch.relu(self.encoder(x))  # 编码过程
    decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
    return decoded, encoded
  • 输入数据通过编码器和解码器处理,输出重构数据和编码数据。

3. 稀疏性损失计算:

csharp 复制代码
def sparsity_loss(self, p_h):
    return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                  (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))
  • 计算稀疏性损失,用于约束神经元的激活。

4. 数据生成:

csharp 复制代码
X = np.random.rand(1000, 64)  # 生成 1000 个样本,每个样本有 64 维光谱特征
  • 模拟生成随机的遥感光谱数据。

5. 数据加载器:

csharp 复制代码
dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 使用 DataLoader 创建批处理数据集。

6. 模型训练:

csharp 复制代码
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        p_h = torch.mean(encoded_data, dim=0)  # 计算激活的平均值
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
  • 在 50 个 epoch 内进行训练,计算重构损失并加入稀疏性损失,更新模型权重。

7. 特征提取:

csharp 复制代码
with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征
  • 在测试阶段,使用训练好的模型提取编码后的特征。

8. 可视化:

csharp 复制代码
plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')
  • 可视化原始数据和编码后的特征进行比较。

3. 总结

稀疏自编码器(SAE)是一种强大的特征学习模型,能够提取数据中的细微特征。通过引入稀疏性约束,SAE 有效地鼓励网络学习有用的特征表示。

在遥感领域,SAE 可用于混合像元的分解,帮助识别和分离不同地物的光谱特征。通过简单的 Python 实现,我们展示了如何使用 SAE 处理遥感数据,并可视化其效果。

相关推荐
DES 仿真实践家12 分钟前
【Day 11-N22】Python类(3)——Python的继承性、多继承、方法重写
开发语言·笔记·python
张较瘦_1 小时前
[论文阅读] 人工智能 + 软件工程 | 需求获取访谈中LLM生成跟进问题研究:来龙去脉与创新突破
论文阅读·人工智能
一 铭2 小时前
AI领域新趋势:从提示(Prompt)工程到上下文(Context)工程
人工智能·语言模型·大模型·llm·prompt
云泽野5 小时前
【Java|集合类】list遍历的6种方式
java·python·list
麻雀无能为力6 小时前
CAU数据挖掘实验 表分析数据插件
人工智能·数据挖掘·中国农业大学
时序之心6 小时前
时空数据挖掘五大革新方向详解篇!
人工智能·数据挖掘·论文·时间序列
IMPYLH6 小时前
Python 的内置函数 reversed
笔记·python
.30-06Springfield6 小时前
人工智能概念之七:集成学习思想(Bagging、Boosting、Stacking)
人工智能·算法·机器学习·集成学习
说私域7 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的超级文化符号构建路径研究
人工智能·小程序·开源
永洪科技7 小时前
永洪科技荣获商业智能品牌影响力奖,全力打造”AI+决策”引擎
大数据·人工智能·科技·数据分析·数据可视化·bi