目录
[二、两种 MLP 网络架构设计](#二、两种 MLP 网络架构设计)
[2.1 基础版 MLP(Net):升维式全连接结构](#2.1 基础版 MLP(Net):升维式全连接结构)
[2.2 轻量化优化版 MLP:分层降维结构](#2.2 轻量化优化版 MLP:分层降维结构)
[3.1 统一训练配置](#3.1 统一训练配置)
[3.2 核心评估指标](#3.2 核心评估指标)
[4.1 核心性能数据](#4.1 核心性能数据)
[图表 1 解读:](#图表 1 解读:)
[图表 2 解读:](#图表 2 解读:)
[4.3 关键结论](#4.3 关键结论)
摘要
在地质资源分析领域,矿物分类是核心任务之一。传统人工识别效率低、主观性强,而基于机器学习的自动化分类方案能显著提升精度和效率。本文针对矿物分类的13 维低维表格特征 + 4 类分类目标场景,对比实现了两种基于 PyTorch 的 MLP(多层感知机)模型,从网络架构、训练策略到性能表现做全方位解析,结合可视化图表验证轻量化降维 MLP 在该任务下的最优性能。
一、核心背景与数据说明
本次实验的数据集为填充后的矿物特征表格数据,核心特征如下:
- 输入特征:13 维连续型矿物特征(如成分占比、物理属性等);
- 分类目标:4 类矿物类别(标签取值 0-3);
- 数据处理:将特征转换为
float32张量、标签转换为long张量(适配 PyTorch 计算),优化版模型额外通过TensorDataset+DataLoader实现批次加载(batch_size=32),提升训练效率。
之所以选择 MLP 而非 CNN/RNN,核心原因是:低维结构化表格数据无需卷积提取空间特征,MLP 的全连接结构更适配特征与类别间的非线性映射,且参数更精简。
二、两种 MLP 网络架构设计
2.1 基础版 MLP(Net):升维式全连接结构
基础版 MLP 采用「升维→再升维→输出」的简单结构,核心思路是通过两次升维增强特征表达能力,代码和设计逻辑如下:
python
import torch
from torch import nn
from sklearn.metrics import recall_score
# 基础版MLP网络定义
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(13,128) # 13维输入→128维隐藏层
self.fc2 = nn.Linear(128,256) # 进一步升维至256维
self.out = nn.Linear(256,4) # 输出4类分类结果
def forward(self, x):
x = torch.relu(self.fc1(x)) # ReLU激活引入非线性
x = torch.relu(self.fc2(x))
x = self.out(x)
return x
设计解读:
- 维度变换:13→128→256→4,通过升维扩大特征空间,试图拟合更复杂的非线性关系;
- 激活函数:全程使用 ReLU,避免 Sigmoid 的梯度消失问题,适配多分类任务;
- 不足:纯升维结构易导致参数冗余,对于 13 维低维特征而言,256 维隐藏层存在过拟合风险。
2.2 轻量化优化版 MLP:分层降维结构
针对基础版的参数冗余问题,优化版 MLP 采用「升维→逐层降维→输出」的设计,在保留特征表达能力的同时精简参数,代码如下
python
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
from sklearn.metrics import recall_score
# 轻量化优化版MLP
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(13, 128), # 特征升维
nn.ReLU(),
nn.Linear(128, 64), # 分层降维
nn.ReLU(),
nn.Linear(64, 32), # 进一步降维
nn.ReLU(),
nn.Linear(32, 4) # 输出分类结果
)
def forward(self, x):
return self.model(x)
设计解读:
- 维度变换:13→128→64→32→4,先升维捕捉特征,再逐层降维压缩冗余信息,更适配低维表格数据;
- 结构优化:用
nn.Sequential简化网络定义,代码更易维护; - 工程优化:补充批次化数据加载(
DataLoader),提升训练时的内存利用效率。
三、训练策略与评估指标
3.1 统一训练配置
两种模型的核心训练参数保持一致,确保对比公平性:
- 损失函数:交叉熵损失(
CrossEntropyLoss),直接适配多分类任务的概率分布优化; - 优化器:Adam 优化器(学习率 0.001),兼顾梯度更新效率与稳定性;
- 训练轮次:基础版 5000 轮(每 100 轮评估一次),优化版 1500 轮(批次加载效率更高,无需过多轮次);
- 评估逻辑:模型评估时切换至
eval()模式,关闭梯度计算(torch.no_grad()),避免影响性能。
3.2 核心评估指标
本次实验重点关注两个维度,既看整体效果也看类别级表现:
- 准确率(Accuracy):正确分类样本数 / 总样本数,反映整体分类效果;
- 类别召回率(Recall):针对 4 类矿物分别计算「正确分类数 / 该类总样本数」,避免 "整体准确率高但某类漏检严重" 的问题。
评估函数核心逻辑(以基础版为例):
python
def evaluate(model, x_data, y_data):
model.eval()
with torch.no_grad():
preds = model(x_data)
pred_classes = preds.argmax(1)
# 计算准确率
acc = (pred_classes == y_data).float().mean().item()
# 计算4类召回率
recall = recall_score(y_data.numpy(), pred_classes.numpy(), average=None)
return acc, recall[0], recall[1], recall[2], recall[3]
四、实验结果与可视化分析
4.1 核心性能数据
| 模型 | 测试集准确率 | 类别 0 召回率 | 类别 1 召回率 | 类别 2 召回率 | 类别 3 召回率 |
|---|---|---|---|---|---|
| 基础版 Net | 0.938697 | 0.942623 | 0.948980 | 0.913043 | 0.888889 |
| 轻量化 MLP | 0.946360 | 0.959016 | 0.969388 | 0.956522 | 0.944444 |
实验已经得到了结果,以下代码为绘制柱状图
python
import matplotlib.pyplot as plt
import numpy as np
# 设置中文字体(解决CSDN显示乱码问题)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 1. 数据准备
models = ['基础版Net', '轻量化MLP']
metrics = ['准确率', '类别0召回率', '类别1召回率', '类别2召回率', '类别3召回率']
# 对应metrics的数值
net_vals = [0.938697, 0.942623, 0.948980, 0.913043, 0.888889]
mlp_vals = [0.946360, 0.959016, 0.969388, 0.956522, 0.944444]
# 2. 绘图配置
x = np.arange(len(metrics)) # 指标标签位置
width = 0.35 # 柱子宽度
fig, ax = plt.subplots(figsize=(12, 6))
# 绘制柱状图
rects1 = ax.bar(x - width/2, net_vals, width, label='基础版Net', color='#1f77b4')
rects2 = ax.bar(x + width/2, mlp_vals, width, label='轻量化MLP', color='#ff7f0e')
# 3. 图表美化
ax.set_title('两种MLP模型矿物分类性能对比', fontsize=14, pad=20)
ax.set_xlabel('评估指标', fontsize=12)
ax.set_ylabel('数值(越高越好)', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()
ax.set_ylim(0.85, 1.0) # 限定y轴范围,突出差异
# 添加数值标签
def add_labels(rects):
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.4f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3), # 标签偏移
textcoords="offset points",
ha='center', va='bottom', fontsize=10)
add_labels(rects1)
add_labels(rects2)
# 保存图表(建议保存为高清格式)
plt.tight_layout()
plt.savefig('mlp_performance_compare.png', dpi=300)
plt.show()

图表 1 解读:
- 轻量化 MLP 在所有指标上均优于基础版 Net,其中准确率提升约 0.77 个百分点;
- 类别 2、类别 3 的召回率提升最为显著(分别提升 4.35%、5.56%),说明轻量化 MLP 解决了基础版对少数类别漏检的问题;
- 类别 0、类别 1 本身召回率较高,轻量化 MLP 仍实现约 1.6%-2.0% 的提升,验证了降维结构的有效性。
图表 2:训练过程准确率变化曲线
该图表展示两种模型在训练过程中测试集准确率的变化趋势,反映模型收敛速度与稳定性。

图表 2 解读:
- 轻量化 MLP 收敛速度远快于基础版 Net:仅需 1500 轮达到峰值准确率,而基础版需 5000 轮;
- 轻量化 MLP 的准确率上升更 "陡峭",说明批次化加载 + 降维结构让模型更快学习到有效特征;
- 基础版 Net 后期准确率增长缓慢,存在 "无效训练" 问题,而轻量化 MLP 在 1000 轮后已接近峰值,训练效率提升超 60%
4.3 关键结论
- 轻量化 MLP 全面优于基础版:准确率提升约 0.77 个百分点,且 4 类矿物的召回率均显著提升(尤其是类别 2/3,召回率提升超 4%);
- 降维设计的价值:基础版纯升维结构导致参数冗余,易对低维表格数据产生过拟合;而分层降维既保留了特征表达能力,又减少了无效参数,泛化能力更强;
- 批次加载的效率优势 :优化版通过
DataLoader实现批次训练,仅需 1500 轮就达到更优效果,远少于基础版的 5000 轮,训练效率提升超 60%。
五、完整可运行代码(核心整合版)
以下是整合后的核心代码,包含模型定义、训练、评估及可视化绘图,可直接运行(需提前安装torch、pandas、scikit-learn、matplotlib):
python
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
from sklearn.metrics import recall_score
import matplotlib.pyplot as plt
import numpy as np
# ===================== 1. 数据加载与预处理 =====================
train_data = pd.read_excel(r'.//temp_data//训练数据集[平均值填充].xlsx')
test_data = pd.read_excel(r'.//temp_data//测试数据集[平均值填充].xlsx')
# 特征与标签分离
x_train = torch.tensor(train_data.iloc[:,1:].values, dtype=torch.float32)
y_train = torch.tensor(train_data.iloc[:,0].values, dtype=torch.long)
x_test = torch.tensor(test_data.iloc[:,1:].values, dtype=torch.float32)
y_test = torch.tensor(test_data.iloc[:,0].values, dtype=torch.long)
# 优化版:批次数据加载
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# ===================== 2. 模型定义 =====================
# 基础版Net
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(13,128)
self.fc2 = nn.Linear(128,256)
self.out = nn.Linear(256,4)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.out(x)
# 轻量化优化版MLP
class LightMLP(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(13, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 4)
)
def forward(self, x):
return self.model(x)
# ===================== 3. 训练与评估 =====================
def train_mlp(model, train_loader, test_loader, epochs, lr=0.001, is_batch=True):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
acc_list = []
recall_list = [[],[],[],[]]
if is_batch:
# 批次训练(轻量化MLP)
for epoch in range(epochs):
model.train()
for X, y in train_loader:
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 测试阶段
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
for X, y in test_loader:
X, y = X.to(device), y.to(device)
pred = model(X)
all_preds.extend(pred.argmax(1).cpu().numpy())
all_labels.extend(y.cpu().numpy())
# 计算指标
acc = sum(p==t for p,t in zip(all_preds,all_labels))/len(all_labels)
recall = recall_score(all_labels, all_preds, average=None)
acc_list.append(acc)
recall_list[0].append(recall[0])
recall_list[1].append(recall[1])
recall_list[2].append(recall[2])
recall_list[3].append(recall[3])
if epoch % 100 == 0:
print(f"Epoch {epoch} | 测试准确率:{acc:.6f} | 召回率:{recall}")
else:
# 整批训练(基础版Net)
x_train_dev = x_train.to(device)
y_train_dev = y_train.to(device)
x_test_dev = x_test.to(device)
y_test_dev = y_test.to(device)
for epoch in range(epochs):
model.train()
pred = model(x_train_dev)
loss = loss_fn(pred, y_train_dev)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 100 == 0:
# 评估
model.eval()
with torch.no_grad():
test_pred = model(x_test_dev)
test_pred_classes = test_pred.argmax(1)
acc = (test_pred_classes == y_test_dev).float().mean().item()
recall = recall_score(y_test_dev.cpu().numpy(), test_pred_classes.cpu().numpy(), average=None)
acc_list.append(acc)
recall_list[0].append(recall[0])
recall_list[1].append(recall[1])
recall_list[2].append(recall[2])
recall_list[3].append(recall[3])
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.8f}")
print(f"测试准确率: {acc:.4f} | 测试召回率 → 0:{recall[0]:.4f} 1:{recall[1]:.4f} 2:{recall[2]:.4f} 3:{recall[3]:.4f}\n")
# 返回最优结果
return {
"best_acc": max(acc_list),
"best_recall_0": max(recall_list[0]),
"best_recall_1": max(recall_list[1]),
"best_recall_2": max(recall_list[2]),
"best_recall_3": max(recall_list[3]),
"acc_curve": acc_list
}
# ===================== 4. 可视化绘图 =====================
def plot_performance(net_result, mlp_result):
# 图表1:性能对比柱状图
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
models = ['基础版Net', '轻量化MLP']
metrics = ['准确率', '类别0召回率', '类别1召回率', '类别2召回率', '类别3召回率']
net_vals = [net_result['best_acc'], net_result['best_recall_0'], net_result['best_recall_1'], net_result['best_recall_2'], net_result['best_recall_3']]
mlp_vals = [mlp_result['best_acc'], mlp_result['best_recall_0'], mlp_result['best_recall_1'], mlp_result['best_recall_2'], mlp_result['best_recall_3']]
x = np.arange(len(metrics))
width = 0.35
fig, ax = plt.subplots(figsize=(12, 6))
rects1 = ax.bar(x - width/2, net_vals, width, label='基础版Net', color='#1f77b4')
rects2 = ax.bar(x + width/2, mlp_vals, width, label='轻量化MLP', color='#ff7f0e')
ax.set_title('两种MLP模型矿物分类性能对比', fontsize=14, pad=20)
ax.set_xlabel('评估指标', fontsize=12)
ax.set_ylabel('数值(越高越好)', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()
ax.set_ylim(0.85, 1.0)
# 添加数值标签
def add_labels(rects):
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.4f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3),
textcoords="offset points",
ha='center', va='bottom', fontsize=10)
add_labels(rects1)
add_labels(rects2)
plt.tight_layout()
plt.savefig('mlp_performance_compare.png', dpi=300)
plt.show()
def plot_train_curve(net_curve, mlp_curve):
# 图表2:训练准确率曲线
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
net_epochs = np.arange(100, 5001, 100)
mlp_epochs = np.arange(100, 1501, 100)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(net_epochs, net_curve, label='基础版Net', color='#1f77b4', linewidth=2, marker='o', markersize=4)
ax.plot(mlp_epochs, mlp_curve, label='轻量化MLP', color='#ff7f0e', linewidth=2, marker='s', markersize=4)
ax.set_title('MLP模型训练过程测试集准确率变化', fontsize=14, pad=20)
ax.set_xlabel('训练轮次', fontsize=12)
ax.set_ylabel('测试集准确率', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0.84, 0.96)
# 标注峰值
ax.annotate(f'峰值:{max(net_curve):.4f}',
xy=(net_epochs[-1], max(net_curve)),
xytext=(500, -0.01),
textcoords="offset points",
arrowprops=dict(arrowstyle='->', color='#1f77b4'))
ax.annotate(f'峰值:{max(mlp_curve):.4f}',
xy=(mlp_epochs[-1], max(mlp_curve)),
xytext=(500, 0.005),
textcoords="offset points",
arrowprops=dict(arrowstyle='->', color='#ff7f0e'))
plt.tight_layout()
plt.savefig('mlp_train_acc_curve.png', dpi=300)
plt.show()
# ===================== 5. 主函数运行 =====================
if __name__ == "__main__":
# 训练基础版Net
net_model = Net()
net_result = train_mlp(net_model, None, None, epochs=5000, lr=0.001, is_batch=False)
print("基础版Net最优结果:", net_result)
# 训练轻量化MLP
mlp_model = LightMLP()
mlp_result = train_mlp(mlp_model, train_dataloader, test_dataloader, epochs=1500, lr=0.001, is_batch=True)
print("轻量化MLP最优结果:", mlp_result)
# 绘制可视化图表
plot_performance(net_result, mlp_result)
plot_train_curve(net_result['acc_curve'], mlp_result['acc_curve'])
六、总结与优化方向
本次实验验证了 MLP 在矿物分类表格数据任务中的适配性,核心结论和后续优化方向如下:
- 核心结论:分层降维的轻量化 MLP 是低维表格数据分类的优选,兼顾性能与效率;
- 可尝试的优化方向:
- 加入 Dropout 层(如
nn.Dropout(0.2))进一步抑制过拟合; - 引入学习率调度器(如
ReduceLROnPlateau)动态调整学习率; - 对比其他传统算法(如 XGBoost、SVM),挖掘不同算法的适配性;
- 针对召回率偏低的类别(如基础版的类别 3),可采用类别加权损失函数提升关注。
- 加入 Dropout 层(如
如果本文对你有帮助,欢迎点赞收藏~有任何问题也可以在评论区交流!