3D CNN-GRU-Att结合模型:负荷预测、PM2.5预测、光伏预测等实验的代码实现

三维卷积神经网络和门控循环单元加上注意力机制(3D CNN-GRU-Att)的结合上构建的,此代码可以运用到负荷预测、PM2.5预测、光伏预测中等等,只需要替换你自己的数据即可跑所有实验图都是代码跑出来的

江湖上一直流传着时空预测的传说,今天咱们来盘一盘这个能打十个的3D CNN-GRU-Attention组合拳。这玩意儿在电力负荷预测里能把传统模型按在地上摩擦,在PM2.5预测战场上也能七进七出。不整虚的,直接上代码!

先看这个时空特征提取的狠角色------3D卷积。咱们用Keras实现起来跟切菜似的:

python 复制代码
from keras.layers import Conv3D, Reshape

input_layer = Input(shape=(24, 5, 5, 1))  # 24小时历史数据,5x5空间网格

# 第一层3D卷积暴力提取特征
conv1 = Conv3D(filters=64, kernel_size=(3, 3, 3), activation='relu')(input_layer)

这里用3x3x3的核在时空维度上滑动,就像拿着探照灯在时空立方体里找规律。注意输出的形状会自动保持时间维度,这是后续接GRU的关键。

接下来是时间序列处理的扛把子GRU,配上注意力机制简直如虎添翼:

python 复制代码
from keras.layers import GRU, Dense, Multiply

# 把卷积输出压平时间步长
reshape = Reshape((22, -1))(conv1)  # 24-3+1=22个时间步

# GRU捕捉时间依赖
gru_out, gru_state = GRU(128, return_sequences=True, return_state=True)(reshape)

# 注意力机制搞事情
attention = Dense(1, activation='tanh')(gru_out)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
context = Multiply()([gru_out, attention])

这段代码暗藏玄机:GRU不仅返回最终状态,还把每个时间步的输出都吐出来。注意力层就像个智能聚光灯,自动找到关键时间点重点关照。

整个模型拼装起来就像搭乐高:

python 复制代码
from keras.models import Model

# 拼接输出层
output = Dense(24)(context)  # 预测未来24个时间点

model = Model(inputs=input_layer, outputs=output)
model.compile(optimizer='adam', loss='mape')

这里输出层直接预测多个时间点,比传统递归预测更高效。注意损失函数用了MAPE,对负荷预测这种相对误差敏感的场景特别合适。

数据预处理才是真功夫,以电力负荷数据为例:

python 复制代码
def create_dataset(data, look_back=24, pred_steps=24):
    X, Y = [], []
    for i in range(len(data)-look_back-pred_steps):
        # 3D输入需要空间维度,假设有5x5区域数据
        X.append(data[i:i+look_back].reshape(look_back,5,5,1))
        Y.append(data[i+look_back:i+look_back+pred_steps])
    return np.array(X), np.array(Y)

这个reshape操作把一维时间序列变成伪3D数据,实际业务中可能需要根据传感器位置调整空间维度。比如把不同变电站的数据排成网格。

三维卷积神经网络和门控循环单元加上注意力机制(3D CNN-GRU-Att)的结合上构建的,此代码可以运用到负荷预测、PM2.5预测、光伏预测中等等,只需要替换你自己的数据即可跑所有实验图都是代码跑出来的

训练时记得用时空数据增强:

python 复制代码
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=20,  # 空间旋转增强
    width_shift_range=0.2,
    height_shift_range=0.2)

model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train)/32, epochs=100)

这里把图像增强技术用在时空数据上,相当于给模型喂了"时空扭曲"的特训套餐,大幅提升泛化能力。

预测的时候玩点花的------滚动预测未来多步:

python 复制代码
def rolling_prediction(model, init_data, steps=24):
    preds = []
    current_batch = init_data.reshape(1, 24, 5, 5, 1)
    
    for _ in range(steps):
        pred = model.predict(current_batch)[0]
        preds.append(pred[0])
        # 更新输入数据,类似滑动窗口
        current_batch = np.concatenate([current_batch[:,1:], pred.reshape(1,1,5,5,1)], axis=1)
    
    return np.array(preds)

这个滚动预测相当于让模型自己续写时间序列,每次预测下一步时都把最新预测值塞回输入窗口,适合长期预测场景。

在光伏预测中实测发现,加入注意力机制后模型对日出日落的时间点特别敏感。比如当注意力权重突然在早晨6点暴增,说明模型自动捕捉到了光伏发电的启动时刻。

代码里有个暗坑要注意:3D卷积会压缩时间维度。比如用kernel_size=(3,3,3)时,输入24个时间步经过卷积后会变成22个时间步。所以在GRU层前面要确保时间步数量合理,别被卷没了。

最后说句大实话:这套模型在1080Ti上跑起来确实有点烫手,建议把空间维度不要超过10x10。实际工业部署时可以改用 separable convolution 省显存,这个在另一个版本里实现了,点赞过500就放出来!

相关推荐
承渊政道1 天前
Linux系统学习【深入剖析Git的原理和使用(下)】
linux·服务器·git·学习·gitee·vim·gitcode
嵌入小生0072 天前
线程 --- 嵌入式(Linux)
linux·vscode·vim·嵌入式·线程·进程
蜡笔小炘4 天前
Haproxy -- 动/静/混合态算法实验
运维·服务器·vim·haproxy
火山引擎开发者社区5 天前
Seedance 2.0上线火山方舟体验中心,API即将开放
docker·vim·emacs
小心草里有鬼5 天前
VMware虚拟机扩容
linux·后端·centos·vim
嵌入小生0077 天前
进程的基本概念\相关命令\创建\调度\状态及相关函数接口---软件编程---嵌入式(Linux)
linux·vscode·vim·嵌入式·进程·fork·软件编程
嵌入小生0078 天前
文件IO\目录IO\时间接口函数 --- IO编程 --- 嵌入式(Linux)
linux·c语言·vscode·vim·文件io·目录io·时间函数接口
程序员一点8 天前
第7章:文本编辑器使用(vi/vim 与 nano)
linux·编辑器·vim
嵌入小生0079 天前
Standard IO -- Continuation of Core Function Interfaces (Embedded Linux)
linux·vim·嵌入式·标准io·vscode