【深度学习基础模型】稀疏自编码器 (Sparse Autoencoders, SAE)详细理解并附实现代码。

【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model

【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model


文章目录

  • [【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model](#【深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model)
  • [1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用](#1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用)
    • [1.1 SAE 原理](#1.1 SAE 原理)
    • [1.2 SAE 的主要特征:](#1.2 SAE 的主要特征:)
    • [1.3 SAE 的应用领域:](#1.3 SAE 的应用领域:)
  • [2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用](#2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用)
    • [2.1 SAE 模型的实现](#2.1 SAE 模型的实现)
    • [2.2 代码解释](#2.2 代码解释)
  • [3. 总结](#3. 总结)

参考地址:https://www.asimovinstitute.org/neural-network-zoo/

论文地址:https://www.cs.toronto.edu/\~ranzato/publications/ranzato-nips06.pdf

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用

1.1 SAE 原理

稀疏自编码器(Sparse Autoencoder, SAE)是一种特殊类型的自编码器,其设计目的是在编码过程中引入稀疏性,以鼓励网络学习更多的特征。与标准自编码器不同,SAE 的目标是将输入信息编码到一个比输入更高维的空间中,帮助提取多种小特征。

1.2 SAE 的主要特征:

  • 稀疏性:通过引入稀疏性约束,限制在某一时刻只有少数神经元被激活。这可以通过添加一个稀疏性损失项来实现,该项鼓励大多数神经元在输出中保持静默。
  • 网络结构:SAE 通常包含一个较小的中间层,但该中间层的激活仅通过部分神经元实现,从而生成丰富的特征表示。
  • 特征提取:SAE 适用于特征提取,尤其是在需要捕捉数据中细微差别的任务中,如图像分类、异常检测等。

1.3 SAE 的应用领域:

  • 图像处理:可以用于从遥感图像中提取细节特征。
  • 异常检测:在数据中识别异常点。
  • 生物信息学:提取基因表达数据中的重要特征。

在遥感领域,SAE 可以用于混合像元的分解,帮助识别和分离不同地物的光谱特征。

2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用

以下是一个简单的稀疏自编码器实现示例,展示如何在遥感图像的混合像元分解中应用 SAE。

2.1 SAE 模型的实现

csharp 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# 定义稀疏自编码器模型
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重

    def forward(self, x):
        encoded = torch.relu(self.encoder(x))  # 编码过程
        decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
        return decoded, encoded

    def sparsity_loss(self, p_h):
        # 计算稀疏性损失
        return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                      (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))

# 生成模拟遥感图像数据 (64 维特征)
X = np.random.rand(1000, 64)  # 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)

# 创建数据加载器
dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型、优化器
input_size = 64
hidden_size = 32  # 隐藏层大小
sparsity_param = 0.05  # 稀疏性参数
beta = 1  # 稀疏性损失权重
sae = SparseAutoencoder(input_size=input_size, hidden_size=hidden_size, sparsity_param=sparsity_param, beta=beta)
optimizer = optim.Adam(sae.parameters(), lr=0.001)

# 训练 SAE 模型
num_epochs = 50
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        # 计算稀疏性损失
        p_h = torch.mean(encoded_data, dim=0)
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 使用训练好的模型进行特征提取
with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征

# 可视化原始数据和编码后的特征
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')

plt.subplot(1, 2, 2)
plt.title('Encoded Features')
plt.imshow(encoded_data[:10], aspect='auto', cmap='hot')

plt.show()

2.2 代码解释

1. 模型定义:

csharp 复制代码
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重
  • SparseAutoencoder 类定义了编码器和解码器的结构,同时定义了稀疏性参数和稀疏性损失权重。

2. 前向传播:

csharp 复制代码
def forward(self, x):
    encoded = torch.relu(self.encoder(x))  # 编码过程
    decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
    return decoded, encoded
  • 输入数据通过编码器和解码器处理,输出重构数据和编码数据。

3. 稀疏性损失计算:

csharp 复制代码
def sparsity_loss(self, p_h):
    return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                  (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))
  • 计算稀疏性损失,用于约束神经元的激活。

4. 数据生成:

csharp 复制代码
X = np.random.rand(1000, 64)  # 生成 1000 个样本,每个样本有 64 维光谱特征
  • 模拟生成随机的遥感光谱数据。

5. 数据加载器:

csharp 复制代码
dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 使用 DataLoader 创建批处理数据集。

6. 模型训练:

csharp 复制代码
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        p_h = torch.mean(encoded_data, dim=0)  # 计算激活的平均值
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
  • 在 50 个 epoch 内进行训练,计算重构损失并加入稀疏性损失,更新模型权重。

7. 特征提取:

csharp 复制代码
with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征
  • 在测试阶段,使用训练好的模型提取编码后的特征。

8. 可视化:

csharp 复制代码
plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')
  • 可视化原始数据和编码后的特征进行比较。

3. 总结

稀疏自编码器(SAE)是一种强大的特征学习模型,能够提取数据中的细微特征。通过引入稀疏性约束,SAE 有效地鼓励网络学习有用的特征表示。

在遥感领域,SAE 可用于混合像元的分解,帮助识别和分离不同地物的光谱特征。通过简单的 Python 实现,我们展示了如何使用 SAE 处理遥感数据,并可视化其效果。

相关推荐
Tiansan666613 小时前
郑州AI问答服务商崛起:专业团队如何重塑企业客服
人工智能·郑州ai问答服务商崛
DeniuHe13 小时前
sklearn 中所有交叉验证数据集划分方式完整总结
人工智能·python·sklearn
DeniuHe13 小时前
sklearn中不同交叉验证方法的场景适配
人工智能·python·sklearn
小新同学^O^13 小时前
简单学习 --> 指令微调
人工智能·学习·llm·指令微调
知识浅谈13 小时前
Transformer 中的 Q、K、V 到底是什么?怎么理解 Query、Key、Value?
人工智能·深度学习·transformer
名不经传的养虾人13 小时前
从0到1:企业级AI项目迭代日记 Vol.36|临时方案下线,网关区分负载,用量穿透链路——这一周全是“归位”
人工智能·ai编程·ai工作流·企业ai·多agent协作
風清掦13 小时前
【STM32学习笔记-14】WDG看门狗 - 14.2 WWDG窗口看门狗
笔记·stm32·单片机·嵌入式硬件·学习·fpga开发
小程故事多_8013 小时前
拆解Hermes Agent技术架构,会自我迭代的开源智能体如何突破AI传统局限
人工智能·架构·开源
黎阳之光13 小时前
数智透明·安全兜底|黎阳之光透明矿山,AI+数字孪生守护矿山生命线
人工智能·物联网·算法·安全·数字孪生
Bigger13 小时前
mini-cc 的 MCP 协议:给 AI 装个 USB-C 接口
人工智能·ai编程·claude