Preface
Series column: 【深度学习:算法项目实战】✨︎
The series spans healthcare, finance, retail, food and beverage, fitness, transportation, environmental science, social media, and text and image processing, and discusses a range of deep neural network ideas, including convolutional neural networks, recurrent neural networks, generative adversarial networks, gated recurrent units, long short-term memory, natural language processing, deep reinforcement learning, large language models, and transfer learning.
Rainfall prediction is an important research topic in meteorology and hydrology, with major implications for agriculture, urban water-resource management, and disaster early warning. Traditional approaches rely mainly on statistical and physical models, which often struggle with nonlinear, high-dimensional time-series data. In recent years, the rapid development of deep learning, and in particular the broad adoption of convolutional neural networks (CNNs), has opened up new approaches to rainfall prediction.
The one-dimensional convolutional neural network (Conv1D), a variant of the CNN, is particularly well suited to time-series data. It automatically extracts features from the data and, through stacked convolution and pooling operations, effectively captures complex patterns and long-range dependencies in a sequence. Applying Conv1D to multivariate time-series classification of rainfall, that is, predicting whether it will rain tomorrow, therefore has real research value and application potential.
This study explores a Conv1D-based approach to multivariate time-series classification of rainfall. By building a suitable network model that exploits the temporal information and multivariate features in historical weather data, we aim to improve the accuracy and stability of rainfall prediction. We also touch on model interpretability, to better support practical applications of rainfall forecasting.
Table of Contents
- [1. Dataset Introduction](#1. Dataset Introduction)
- [2. Data Visualization](#2. Data Visualization)
  - [2.1 Checking for Missing Data](#2.1 Checking for Missing Data)
  - [2.2 Checking Whether the Data Is Balanced](#2.2 Checking Whether the Data Is Balanced)
- [3. Data Preprocessing](#3. Data Preprocessing)
  - [3.1 Data Cleaning: Filling Missing Values](#3.1 Data Cleaning: Filling Missing Values)
    - [3.1.1 Categorical Variables](#3.1.1 Categorical Variables)
    - [3.1.2 Numeric Variables](#3.1.2 Numeric Variables)
  - [3.2 Outlier Detection](#3.2 Outlier Detection)
    - [3.2.1 Outlier Detection for Numeric Variables](#3.2.1 Outlier Detection for Numeric Variables)
    - [3.2.2 Handling Outliers](#3.2.2 Handling Outliers)
- [4. Feature Engineering](#4. Feature Engineering)
  - [4.1 Label Encoding and One-Hot Encoding](#4.1 Label Encoding and One-Hot Encoding)
  - [4.2 Feature Scaling: Standardization](#4.2 Feature Scaling: Standardization)
  - [4.3 Building the Time-Series Data](#4.3 Building the Time-Series Data)
  - [4.4 Oversampling with SMOTE](#4.4 Oversampling with SMOTE)
  - [4.5 Dataset Splitting](#4.5 Dataset Splitting)
  - [4.6 Dataset Tensors](#4.6 Dataset Tensors)
- [5. Building the Time-Series Classification Model (TSC)](#5. Building the Time-Series Classification Model (TSC))
  - [5.1 Building the TimeSeriesCNN Model](#5.1 Building the TimeSeriesCNN Model)
  - [5.2 Defining the Model, Loss Function, and Optimizer](#5.2 Defining the Model, Loss Function, and Optimizer)
  - [5.3 Model Summary](#5.3 Model Summary)
- [6. Model Training and Visualization](#6. Model Training and Visualization)
  - [6.1 Defining the Training and Evaluation Functions](#6.1 Defining the Training and Evaluation Functions)
  - [6.2 Plotting the Loss and Accuracy Curves](#6.2 Plotting the Loss and Accuracy Curves)
- [7. Model Evaluation and Visualization](#7. Model Evaluation and Visualization)
  - [7.1 Building the Prediction Function](#7.1 Building the Prediction Function)
  - [7.2 Confusion Matrix](#7.2 Confusion Matrix)
  - [7.3 ROC-AUC Curve](#7.3 ROC-AUC Curve)
  - [7.4 Classification Report](#7.4 Classification Report)
1. Dataset Introduction
This dataset contains about 10 years of daily weather observations from many locations across Australia. RainTomorrow is the target variable to predict. It answers a key question: will it rain tomorrow? (Yes or No). The column is marked "Yes" if that day's rainfall was 1 mm or more. Download 🔗
First, let's import the necessary libraries and the dataset.
python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, auc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchinfo import summary
np.random.seed(0)  # fix the NumPy random seed for reproducibility
python
data = pd.read_csv("weatherAUS.csv")
print(data.head().T)
python
0 1 2 3 4
Date 2008-12-01 2008-12-02 2008-12-03 2008-12-04 2008-12-05
Location Albury Albury Albury Albury Albury
MinTemp 13.4 7.4 12.9 9.2 17.5
MaxTemp 22.9 25.1 25.7 28.0 32.3
Rainfall 0.6 0.0 0.0 0.0 1.0
Evaporation NaN NaN NaN NaN NaN
Sunshine NaN NaN NaN NaN NaN
WindGustDir W WNW WSW NE W
WindGustSpeed 44.0 44.0 46.0 24.0 41.0
WindDir9am W NNW W SE ENE
WindDir3pm WNW WSW WSW E NW
WindSpeed9am 20.0 4.0 19.0 11.0 7.0
WindSpeed3pm 24.0 22.0 26.0 9.0 20.0
Humidity9am 71.0 44.0 38.0 45.0 82.0
Humidity3pm 22.0 25.0 30.0 16.0 33.0
Pressure9am 1007.7 1010.6 1007.6 1017.6 1010.8
Pressure3pm 1007.1 1007.8 1008.7 1012.8 1006.0
Cloud9am 8.0 NaN NaN NaN 7.0
Cloud3pm NaN NaN 2.0 NaN 8.0
Temp9am 16.9 17.2 21.0 18.1 17.8
Temp3pm 21.8 24.3 23.2 26.5 29.7
RainToday No No No No No
RainTomorrow No No No No No
The dataset contains roughly 10 years of daily weather observations from many weather stations across Australia. In this project, I will use these data to predict whether it will rain the next day. There are 23 attributes, including the target variable RainTomorrow, which indicates whether it rains the following day.
2. Data Visualization
2.1 Checking for Missing Data
The `.info()` method prints information about a `DataFrame`, including the index dtype and columns, the non-null counts, and memory usage.
python
data.info()
python
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 23 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Date 145460 non-null object
1 Location 145460 non-null object
2 MinTemp 143975 non-null float64
3 MaxTemp 144199 non-null float64
4 Rainfall 142199 non-null float64
5 Evaporation 82670 non-null float64
6 Sunshine 75625 non-null float64
7 WindGustDir 135134 non-null object
8 WindGustSpeed 135197 non-null float64
9 WindDir9am 134894 non-null object
10 WindDir3pm 141232 non-null object
11 WindSpeed9am 143693 non-null float64
12 WindSpeed3pm 142398 non-null float64
13 Humidity9am 142806 non-null float64
14 Humidity3pm 140953 non-null float64
15 Pressure9am 130395 non-null float64
16 Pressure3pm 130432 non-null float64
17 Cloud9am 89572 non-null float64
18 Cloud3pm 86102 non-null float64
19 Temp9am 143693 non-null float64
20 Temp3pm 141851 non-null float64
21 RainToday 142199 non-null object
22 RainTomorrow 142193 non-null object
dtypes: float64(16), object(7)
memory usage: 25.5+ MB
python
sns.heatmap(data.isnull(), cbar=False, cmap='PuBu')  # visualize missing values per column
plt.show()
The heatmap makes it obvious that the dataset contains missing values, in both the numeric and the categorical columns.
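As an optional quick check (a minimal sketch using only pandas), the columns can also be ranked by their fraction of missing values:
python
# Fraction of missing values per column, highest first
missing_ratio = data.isnull().mean().sort_values(ascending=False)
print(missing_ratio.head(10))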
2.2 Checking Whether the Data Is Balanced
Next, we check whether the dataset is balanced. If it is imbalanced, we will need to undersample the majority class or oversample the minority class to restore balance.
python
proportion = data.RainTomorrow.value_counts(normalize = True)
plt.style.use('bmh')
plt.figure(figsize=(8,5))
plt.bar(proportion.index, proportion.values, color=['lightsteelblue', 'slategrey'])
plt.xlabel('RainTomorrow')
plt.ylabel('Proportion')
plt.title('Proportion of RainTomorrow')
plt.xticks(rotation=45, ha="right")
plt.show()
3. Data Preprocessing
Now I will parse the Date column into a datetime data type.
python
print(type(data['RainTomorrow'].iloc[0]),type(data['Date'].iloc[0]))
# Let's convert the data type of the timestamp column to datetime format
data['Date'] = pd.to_datetime(data['Date'])
print(type(data['RainTomorrow'].iloc[0]),type(data['Date'].iloc[0]))
print(data.shape)
python
<class 'str'> <class 'str'>
<class 'str'> <class 'pandas._libs.tslibs.timestamps.Timestamp'>
(145460, 23)
3.1 Data Cleaning: Filling Missing Values
3.1.1 Categorical Variables
Fill the missing values in each categorical column with that column's `mode`.
python
# Selecting columns of categorical variables
object_columns = data.select_dtypes(include=['object']).columns.tolist()
# Missing values in categorical variables
data[object_columns].isnull().sum()
python
Location 0
WindGustDir 10326
WindDir9am 10566
WindDir3pm 4228
RainToday 3261
RainTomorrow 3267
dtype: int64
python
# Filling missing values with the mode of the column
for col in object_columns:
    data.fillna({col: data[col].mode()[0]}, inplace=True)
# Counting missing values
data[object_columns].isnull().sum()
python
Location 0
WindGustDir 0
WindDir9am 0
WindDir3pm 0
RainToday 0
RainTomorrow 0
dtype: int64
3.1.2 Numeric Variables
Fill the missing values in each numeric column with that column's `median`.
python
# Selecting columns of numeric variables
numeric_columns = data.select_dtypes(include=['float64']).columns.tolist()
# Missing values in numeric variables
data[numeric_columns].isnull().sum()
python
MinTemp 1485
MaxTemp 1261
Rainfall 3261
Evaporation 62790
Sunshine 69835
WindGustSpeed 10263
WindSpeed9am 1767
WindSpeed3pm 3062
Humidity9am 2654
Humidity3pm 4507
Pressure9am 15065
Pressure3pm 15028
Cloud9am 55888
Cloud3pm 59358
Temp9am 1767
Temp3pm 3609
dtype: int64
python
# Filling missing values with the median of the column
for col in numeric_columns:
    data.fillna({col: data[col].median()}, inplace=True)
# Counting missing values
data[numeric_columns].isnull().sum()
python
MinTemp 0
MaxTemp 0
Rainfall 0
Evaporation 0
Sunshine 0
WindGustSpeed 0
WindSpeed9am 0
WindSpeed3pm 0
Humidity9am 0
Humidity3pm 0
Pressure9am 0
Pressure3pm 0
Cloud9am 0
Cloud3pm 0
Temp9am 0
Temp3pm 0
dtype: int64
3.2 Outlier Detection
3.2.1 Outlier Detection for Numeric Variables
In statistics and data science, identifying and handling outliers is an important step, because they can substantially distort an analysis. Outliers are observations that differ markedly from the rest of the data. A common and effective way to detect them is the interquartile range (IQR) rule: with Q1 and Q3 the 25th and 75th percentiles and IQR = Q3 - Q1, any value below Q1 - k * IQR or above Q3 + k * IQR (conventionally k = 1.5) is flagged as an outlier.
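As a tiny self-contained illustration of the rule (toy data, not the weather dataset):
python
x = np.array([1, 2, 3, 4, 100])               # 100 is an obvious outlier
q1, q3 = np.percentile(x, [25, 75])           # Q1 = 2.0, Q3 = 4.0
iqr = q3 - q1                                 # IQR = 2.0
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
print(x[(x < lower) | (x > upper)])           # -> [100]

Now, for the weather data: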
python
k = 1.5  # conventional IQR multiplier for the outlier fences
python
# Initialize one horizontal boxen plot per numeric feature
fig, axes = plt.subplots(nrows=len(numeric_columns), ncols=1, figsize=(20, 20))
axes = axes.flatten()
for i, feature in enumerate(numeric_columns):
    # Draw the distribution with horizontal boxes
    sns.boxenplot(x=data[feature], ax=axes[i], color='slategrey', linecolor='grey', orient='h')
    # Detecting outliers with IQR
    Q1 = data[feature].quantile(0.25)
    Q3 = data[feature].quantile(0.75)
    IQR = Q3 - Q1
    lower = Q1 - k * IQR
    upper = Q3 + k * IQR
    axes[i].axvline(x=lower, color='r', linestyle='--', label='lower')
    axes[i].axvline(x=upper, color='b', linestyle='--', label='upper')
    # Tweak the visual presentation
    # sns.despine(ax=axes[i], trim=True, left=True)
    axes[i].text(lower, 0.8, f'lower = {lower}', color='red', fontsize=12)
    axes[i].text(upper, 0.8, f'upper = {upper}', color='blue', fontsize=12)
    axes[i].set_title(f'{feature}')
    axes[i].set_xlabel('')
plt.tight_layout()
plt.show()
3.2.2 Handling Outliers
Next, we use the IQR method to detect outliers and replace them by clipping to the lower and upper fences.
python
def Handle_outlier(data, column):
    # Detect outliers in `column` with the IQR rule and replace them
    Q1 = data[column].quantile(0.25)
    Q3 = data[column].quantile(0.75)
    IQR = Q3 - Q1
    lower = Q1 - k * IQR
    upper = Q3 + k * IQR
    # Clip values outside the fences using np.where and np.logical_or
    data[column] = np.where(
        np.logical_or(data[column] < lower, data[column] > upper),
        np.select([data[column] < lower, data[column] > upper], [lower, upper]),
        data[column])
    # Alternative: drop the outlier rows instead of clipping
    # data = data[(data[column] >= lower) & (data[column] <= upper)]
    return data
python
# Clip outliers in every numeric column (the same 16 float64 columns as above)
for col in numeric_columns:
    data = Handle_outlier(data, column=col)
data = data.reset_index(drop=True)
4. Feature Engineering
4.1 Label Encoding and One-Hot Encoding
Apply label encoding to the binary categorical variables and one-hot encoding to the multi-category discrete features.
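One detail worth noting before we start (a quick sketch): `LabelEncoder` assigns integer codes in sorted order of the class names, so 'No' maps to 0 and 'Yes' maps to 1, making rain the positive class.
python
le_demo = LabelEncoder()
print(le_demo.fit_transform(['No', 'Yes', 'No']))  # -> [0 1 0]
print(le_demo.classes_)                            # -> ['No' 'Yes'] (sorted alphabetically)

With that in mind, encode the actual columns: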
python
# Apply label encoder to RainToday, RainTomorrow
le = LabelEncoder()
data['RainToday'] = le.fit_transform(data['RainToday'])
data['RainTomorrow'] = le.fit_transform(data['RainTomorrow'])
ohe = OneHotEncoder()  # one-hot encode the multi-category discrete features
encoded = ohe.fit_transform(data[['Location',
'WindGustDir',
'WindDir9am',
'WindDir3pm'
]])
encoded_data = pd.DataFrame(encoded.toarray(),columns = ohe.get_feature_names_out())
python
data = pd.concat([data,encoded_data],axis=1)
data = data.drop(['Location',
'WindGustDir',
'WindDir9am',
'WindDir3pm'], axis =1)
print(data.info())
python
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Columns: 116 entries, Date to WindDir3pm_WSW
dtypes: datetime64[ns](1), float64(113), int32(2)
memory usage: 127.6 MB
None
Now let's use the `.corr()` method to look at how the features correlate with the target:
python
correlation = data.corr()
print(correlation["RainTomorrow"].sort_values(ascending=False))
python
RainTomorrow 1.000000
Humidity3pm 0.433167
Rainfall 0.323354
RainToday 0.305744
Cloud3pm 0.291963
...
MaxTemp -0.156313
Temp3pm -0.187675
Pressure3pm -0.209378
Pressure9am -0.228542
Sunshine -0.288223
Name: RainTomorrow, Length: 116, dtype: float64
4.2 Feature Scaling: Standardization
The `StandardScaler()` transformer standardizes each feature to zero mean and unit standard deviation, i.e. z = (x - mean) / std. In machine learning it is commonly used to put features measured on different scales onto a comparable scale, which helps models train and generalize better.
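A minimal sanity check of that property on toy data:
python
toy = np.array([[1.0], [2.0], [3.0]])
z = StandardScaler().fit_transform(toy)  # z = (x - mean) / std
print(z.mean(), z.std())                 # -> 0.0 1.0 (up to floating-point error)

Now scale the actual feature matrix: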
python
# Dividing the features and the target from the dataset
features = data.drop(['Date', 'RainTomorrow'], axis=1)
target = data['RainTomorrow'].values.reshape(-1, 1)
python
# Create a StandardScaler instance, then fit and transform the features into a NumPy array
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
print(features_scaled)
4.3 Building the Time-Series Data
python
time_steps = 10  # each sample uses a 10-day window of features
X_list = []
y_list = []
for i in range(len(features_scaled) - time_steps):
    X_list.append(features_scaled[i:i+time_steps])  # 10 consecutive days of features
    y_list.append(target[i+time_steps])             # label of the following day
X = np.array(X_list)  # [samples, time_steps, num_features]
y = np.array(y_list)  # [samples, 1]
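To see exactly how the windows pair with the labels, here is the same construction on toy data (a sketch with 6 time steps, 1 feature, and a window of 3):
python
toy_feats = np.arange(6).reshape(-1, 1)      # "features" for days 0..5
toy_targ = np.arange(10, 16).reshape(-1, 1)  # "labels" for days 0..5
windows = [toy_feats[i:i+3] for i in range(len(toy_feats) - 3)]
labels_ = [toy_targ[i+3] for i in range(len(toy_feats) - 3)]
print(np.array(windows).shape)    # -> (3, 3, 1): windows [0,1,2], [1,2,3], [2,3,4]
print(np.array(labels_).ravel())  # -> [13 14 15]: each label is the day after its window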
4.4 Oversampling with SMOTE
SMOTE (Synthetic Minority Oversampling Technique) is one of the most widely used oversampling methods for class imbalance. Rather than simply duplicating minority-class instances, it balances the class distribution by synthesizing new minority-class instances: for each minority instance, one of its k nearest minority-class neighbors is chosen at random, and a virtual training record is generated by linear interpolation between the two, i.e. x_new = x_i + lambda * (x_neighbor - x_i) with lambda drawn uniformly from [0, 1]. After oversampling, the rebalanced data can be fed to any classification model.
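The interpolation step itself is easy to illustrate (a toy sketch of the idea, not imblearn's internal implementation):
python
rng = np.random.default_rng(42)
x_i = np.array([1.0, 2.0])        # a minority-class sample
x_nn = np.array([3.0, 1.0])       # one of its k nearest minority neighbors
lam = rng.random()                # lambda ~ U(0, 1)
x_new = x_i + lam * (x_nn - x_i)  # synthetic sample on the segment between them
print(x_new)

Now we apply SMOTE to our windowed data: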
python
samples, time_steps, num_features = X.shape
# Reshape X to 2-D, because SMOTE expects 2-D input
X_reshaped = X.reshape(samples, time_steps * num_features)
# Oversampling
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_reshaped, y)
# Reshape X_resampled back to the original 3-D shape
X_resampled = X_resampled.reshape(-1, time_steps, num_features)
4.5 Dataset Splitting
python
X_train, X_valid, y_train, y_valid = train_test_split(
    X_resampled, y_resampled, test_size=0.25, random_state=12345)
print(X_train.shape, X_valid.shape, y_train.shape, y_valid.shape)
4.6 Dataset Tensors
python
# Convert the NumPy arrays to PyTorch tensors
X_train_tensor = torch.from_numpy(X_train).type(torch.Tensor)
X_valid_tensor = torch.from_numpy(X_valid).type(torch.Tensor)
y_train_tensor = torch.from_numpy(y_train).type(torch.Tensor).view(-1, 1)
y_valid_tensor = torch.from_numpy(y_valid).type(torch.Tensor).view(-1, 1)
print(X_train_tensor.shape, X_valid_tensor.shape, y_train_tensor.shape, y_valid_tensor.shape)
python
torch.Size([170361, 10, 114]) torch.Size([56787, 10, 114]) torch.Size([170361, 1]) torch.Size([56787, 1])
`.type(torch.Tensor)` explicitly casts a tensor to the default float tensor type, while `.type(torch.long)` would cast label tensors to the long integer type `torch.long`, which is commonly used for integer class labels. Here the labels are kept as floats because `BCEWithLogitsLoss` expects float targets.
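A quick demonstration of the two casts (a minimal sketch):
python
t = torch.tensor([0., 1., 1.])
print(t.type(torch.Tensor).dtype)  # -> torch.float32 (the default float type)
print(t.type(torch.long).dtype)    # -> torch.int64

Next, we wrap the tensors in a small Dataset/DataLoader helper: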
python
class DataHandler(Dataset):
    def __init__(self, X_train_tensor, y_train_tensor, X_valid_tensor, y_valid_tensor):
        self.X_train_tensor = X_train_tensor
        self.y_train_tensor = y_train_tensor
        self.X_valid_tensor = X_valid_tensor
        self.y_valid_tensor = y_valid_tensor

    def __len__(self):
        return len(self.X_train_tensor)

    def __getitem__(self, idx):
        sample = self.X_train_tensor[idx]
        labels = self.y_train_tensor[idx]
        return sample, labels

    def train_loader(self):
        train_dataset = TensorDataset(self.X_train_tensor, self.y_train_tensor)
        return DataLoader(train_dataset, batch_size=32, shuffle=True)

    def valid_loader(self):
        valid_dataset = TensorDataset(self.X_valid_tensor, self.y_valid_tensor)
        return DataLoader(valid_dataset, batch_size=32, shuffle=False)
The code above defines a class named `DataHandler` that inherits from `torch.utils.data.Dataset`.
- The `__init__` method receives the data and labels.
- The `__len__` method returns the length of the dataset.
- The `__getitem__` method returns the data sample and label for a given index `idx`.
- The `train_loader` and `valid_loader` methods wrap the tensors in `TensorDataset`s and return batched `DataLoader`s (shuffled for training, unshuffled for validation).
python
data_handler = DataHandler(X_train_tensor, y_train_tensor, X_valid_tensor, y_valid_tensor)
train_loader = data_handler.train_loader()
valid_loader = data_handler.valid_loader()
5. Building the Time-Series Classification Model (TSC)
5.1 Building the TimeSeriesCNN Model
A TimeSeriesCNN (time-series convolutional neural network) is a deep learning model designed specifically for time-series data: sequences of data points collected in time order, which appear in every domain, such as stock prices in finance, temperature records in meteorology, and sensor monitoring logs.
The model is built on the convolutional neural network (CNN) architecture. By convolving over the time axis, it automatically learns feature patterns in the series, which can then be used for forecasting future values or, as here, for classifying the sequence.
python
class TimeSeriesCNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TimeSeriesCNN, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, 128, kernel_size=3)
        self.conv2 = nn.Conv1d(128, 64, kernel_size=3)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # After conv1/pool/conv2/pool, a window of time_steps = 10 shrinks to length 1 (64 channels)
        self.fc1 = nn.Linear(64 * ((time_steps - 4) // 4), 16)
        self.dropout1 = nn.Dropout(p=0.4)  # first Dropout layer, drop probability 0.4
        self.fc2 = nn.Linear(16, output_dim)
        self.dropout2 = nn.Dropout(p=0.2)  # second Dropout layer, drop probability 0.2

    def forward(self, x):
        # [batch, time_steps, features] -> [batch, features (channels), time_steps]
        x = x.permute(0, 2, 1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)  # apply Dropout after the first fully connected layer
        x = self.fc2(x)
        x = self.dropout2(x)  # apply Dropout before returning the output logits
        return x
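As a quick sanity check of the shapes (a sketch assuming `time_steps = 10` and `num_features = 114`, matching the tensors above):
python
dummy = torch.randn(2, 10, 114)  # [batch, time_steps, features]
m = TimeSeriesCNN(input_dim=114, output_dim=1)
print(m(dummy).shape)            # -> torch.Size([2, 1]): one logit per sample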
5.2 Defining the Model, Loss Function, and Optimizer
python
model = TimeSeriesCNN(input_dim=num_features, output_dim=1)
criterion = torch.nn.BCEWithLogitsLoss()  # binary cross-entropy loss with a built-in sigmoid
optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)  # Adam optimizer
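`BCEWithLogitsLoss` combines a sigmoid with binary cross-entropy in a single, numerically stable operation, which is why the model outputs raw logits. A small check of the equivalence:
python
logits = torch.tensor([[0.2], [-1.3]])
targets = torch.tensor([[1.0], [0.0]])
loss_a = nn.BCEWithLogitsLoss()(logits, targets)
loss_b = nn.BCELoss()(torch.sigmoid(logits), targets)
print(loss_a.item(), loss_b.item())  # the two values match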
5.3 Model Summary
python
summary(model, (32, time_steps, num_features)) # batch_size, seq_len(time_steps), input_size(in_channels)
python
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
TimeSeriesCNN [32, 1] --
├─Conv1d: 1-1 [32, 128, 8] 43,904
├─ReLU: 1-2 [32, 128, 8] --
├─MaxPool1d: 1-3 [32, 128, 4] --
├─Conv1d: 1-4 [32, 64, 2] 24,640
├─ReLU: 1-5 [32, 64, 2] --
├─MaxPool1d: 1-6 [32, 64, 1] --
├─Flatten: 1-7 [32, 64] --
├─Linear: 1-8 [32, 16] 1,040
├─ReLU: 1-9 [32, 16] --
├─Dropout: 1-10 [32, 16] --
├─Linear: 1-11 [32, 1] 17
├─Dropout: 1-12 [32, 1] --
==========================================================================================
Total params: 69,601
Trainable params: 69,601
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 12.85
==========================================================================================
Input size (MB): 0.15
Forward/backward pass size (MB): 0.30
Params size (MB): 0.28
Estimated Total Size (MB): 0.72
==========================================================================================
6. Model Training and Visualization
6.1 Defining the Training and Evaluation Functions
Define a `binary_accuracy` function to measure model performance.
python
def binary_accuracy(outputs, labels):
    # Map the raw outputs to [0, 1] with the sigmoid function
    outputs = torch.sigmoid(outputs)
    # Compare against 0.5 to get the predicted class (0 or 1)
    predicted = (outputs > 0.5).float()
    # Count the correct predictions
    correct = (predicted == labels).float().sum()
    # Total number of samples
    total = labels.size(0)
    # Accuracy
    accuracy = correct / total
    return accuracy
The code above defines a function named `binary_accuracy` that computes the accuracy for a binary classification task. It takes the model outputs `outputs` and the true labels `labels` as arguments and returns the resulting accuracy.
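A quick sanity check on hand-made logits:
python
demo_logits = torch.tensor([[2.0], [-1.0], [0.3]])
demo_labels = torch.tensor([[1.0], [0.0], [0.0]])
print(binary_accuracy(demo_logits, demo_labels))  # -> tensor(0.6667): 2 of 3 correct

Next comes the training loop for a single epoch: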
python
def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.train()  # make sure the model is in training mode
    for batch in iterator:
        optimizer.zero_grad()  # clear the gradients
        inputs, labels = batch  # unpack inputs and labels
        outputs = model(inputs)  # forward pass
        # Compute loss and accuracy
        loss = criterion(outputs, labels)
        acc = binary_accuracy(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate loss and accuracy
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    # Average loss and accuracy over the epoch
    average_loss = epoch_loss / len(iterator)
    average_acc = epoch_acc / len(iterator)
    return average_loss, average_acc
The code above defines a `train` function that trains the given model. It takes the model, a data iterator, the optimizer, and the loss function as arguments, and returns the average loss and average accuracy over the epoch.
python
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()  # switch to evaluation mode (disables Dropout, etc.)
    with torch.no_grad():  # no gradients needed
        for batch in iterator:
            inputs, labels = batch
            outputs = model(inputs)  # forward pass
            # Compute loss and accuracy
            loss = criterion(outputs, labels)
            acc = binary_accuracy(outputs, labels)
            # Accumulate loss and accuracy
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
The code above defines an `evaluate` function that measures a model's performance on a given data iterator. It takes the model, the iterator, and the loss function as arguments and returns the average loss and accuracy. This function is typically called periodically during training to monitor performance on the validation or test set, which shows how well the model generalizes and how training is progressing.
python
import copy

best_acc = 0
num_epochs = 100
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []
for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    train_accs.append(train_acc)
    valid_accs.append(valid_acc)
    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc * 100:.2f}%, Val. Loss: {valid_loss:.3f}, Val. Acc: {valid_acc * 100:.2f}%')
    if best_acc <= valid_acc:
        best_acc = valid_acc
        # deep-copy the weights; state_dict() returns references that keep changing as training continues
        pth = copy.deepcopy(model.state_dict())
python
Epoch: 01, Train Loss: 0.671, Train Acc: 57.89%, Val. Loss: 0.648, Val. Acc: 63.78%
Epoch: 02, Train Loss: 0.650, Train Acc: 61.34%, Val. Loss: 0.630, Val. Acc: 66.18%
Epoch: 03, Train Loss: 0.635, Train Acc: 63.13%, Val. Loss: 0.615, Val. Acc: 67.68%
Epoch: 04, Train Loss: 0.623, Train Acc: 64.47%, Val. Loss: 0.599, Val. Acc: 69.12%
Epoch: 05, Train Loss: 0.610, Train Acc: 65.71%, Val. Loss: 0.585, Val. Acc: 70.85%
******
Epoch: 96, Train Loss: 0.338, Train Acc: 82.18%, Val. Loss: 0.379, Val. Acc: 82.73%
Epoch: 97, Train Loss: 0.338, Train Acc: 82.32%, Val. Loss: 0.376, Val. Acc: 82.82%
Epoch: 98, Train Loss: 0.338, Train Acc: 82.31%, Val. Loss: 0.368, Val. Acc: 83.33%
Epoch: 99, Train Loss: 0.335, Train Acc: 82.44%, Val. Loss: 0.368, Val. Acc: 83.42%
Epoch: 100, Train Loss: 0.334, Train Acc: 82.51%, Val. Loss: 0.373, Val. Acc: 83.16%
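With training finished, the best weights kept in `pth` can be persisted and restored before evaluation (a minimal sketch; the filename is hypothetical):
python
torch.save(pth, "best_timeseries_cnn.pth")                   # hypothetical filename
model.load_state_dict(torch.load("best_timeseries_cnn.pth"))
model.eval()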
6.2 Plotting the Loss and Accuracy Curves
python
# Plot the loss curves
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(valid_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss')
plt.legend()
plt.grid(True)
# Plot the accuracy curves
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(valid_accs, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Train and Validation Accuracy')
plt.legend()
plt.grid(True)
plt.show()
7. Model Evaluation and Visualization
7.1 Building the Prediction Function
Define a `prediction` function so it can be conveniently reused.
python
# Define the prediction function
def prediction(model, valid_loader):
    all_labels = []
    all_predictions = []
    all_predictions_prob = []
    model.eval()
    with torch.no_grad():
        for inputs, labels in valid_loader:
            outputs = model(inputs)
            predictions_prob = torch.sigmoid(outputs)     # probabilities in [0, 1]
            predicted = (predictions_prob > 0.5).float()  # hard class predictions
            all_labels.extend(labels.numpy())
            all_predictions.extend(predicted.numpy())
            all_predictions_prob.extend(predictions_prob.numpy())
    return all_labels, all_predictions, all_predictions_prob
The code above defines a `prediction` function that runs the model over the validation data loader (`valid_loader`) and returns the true labels, the predicted classes, and the predicted probabilities. Such a function is typically used after training to collect predictions for further analysis, e.g. computing accuracy or plotting a confusion matrix, and it can equally be used in practice to predict on new data and support decisions.
python
# Prediction results on the validation set
labels, predictions, predictions_prob = prediction(model, valid_loader)
7.2 Confusion Matrix
python
def plot_confusion_matrix(labels, predictions, classes):
    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Annotate each cell with its count, switching text color for contrast
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
The code above defines a `plot_confusion_matrix` function that plots the confusion matrix for the given true labels and predictions. The confusion matrix is a visual tool for evaluating a classifier: it shows how accurately the model predicts each class.
python
classes = ['Class 0', 'Class 1']
Plot the confusion matrix:
python
plot_confusion_matrix(labels, predictions, classes)
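Since scikit-learn orders a binary confusion matrix as [[TN, FP], [FN, TP]], the raw counts can also be unpacked directly:
python
tn, fp, fn, tp = confusion_matrix(labels, predictions).ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")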
7.3 ROC-AUC Curve
python
def plot_roc_curve(labels, predictions_prob):
    fpr, tpr, _ = roc_curve(labels, predictions_prob)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.show()
python
# Plot the ROC curve
plot_roc_curve(labels, predictions_prob)
7.4 Classification Report
python
from sklearn.metrics import classification_report
print(classification_report(labels, predictions))
python
              precision    recall  f1-score   support

         0.0       0.81      0.86      0.84     28186
         1.0       0.85      0.80      0.83     28601

    accuracy                           0.83     56787
   macro avg       0.83      0.83      0.83     56787
weighted avg       0.83      0.83      0.83     56787