Reproducing the OpenSTL PredRNNv2 Model and Training on a Custom Dataset
Overview
This article walks through reproducing the PredRNNv2 model from OpenSTL and training and predicting on a custom dataset stored as NPY files. Starting from environment setup, it covers data preprocessing, model construction, the training loop, and inference, with the end goal of feeding in a sequence of consecutive 500×500 frames and producing a corresponding number of predicted frames.
Contents
- Environment Setup and Dependency Installation
- Dataset Preparation and Preprocessing
- PredRNNv2 Model Principles and Architecture
- Data Loader Implementation
- Model Training Pipeline
- Prediction and Result Visualization
- Model Evaluation and Optimization
- Complete Code Implementation
- Common Problems and Solutions
- Summary and Outlook
1. Environment Setup and Dependency Installation
First, create a suitable Python environment and install all of the required dependencies.
bash
# Create a conda environment
conda create -n openstl python=3.8
conda activate openstl
# Install PyTorch (pick the build that matches your CUDA version)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
# Install other dependencies
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0
Next, clone the OpenSTL repository and install it in editable mode:
bash
git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .
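Before preparing the data, a quick sanity check helps confirm that the installed PyTorch build sees the GPU and that the editable OpenSTL install is importable. This is a minimal sketch; we assume the package is exposed under the `openstl` name:
python
import torch

# Confirm the PyTorch build and GPU visibility
print(torch.__version__)           # e.g. 1.13.1+cu116
print(torch.cuda.is_available())   # should be True on a correctly configured CUDA machine

# Confirm that the editable OpenSTL install is importable (assumed package name)
import openstl
print(openstl.__file__)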
2. Dataset Preparation and Preprocessing
The dataset consists of NPY files, one 500×500 frame per file, and consecutive files within a sequence are temporally contiguous. The expected directory layout is:
dataset/
├── train/
│ ├── sequence_001/
│ │ ├── frame_001.npy
│ │ ├── frame_002.npy
│ │ └── ...
│ ├── sequence_002/
│ └── ...
├── valid/
└── test/
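Before writing any loader code, it is worth verifying that every frame really is a 500×500 array with a consistent dtype and value range; this determines how the data should be normalized later. A minimal inspection sketch (the sequence path below is illustrative):
python
import os
import numpy as np

# Hypothetical path to one training sequence; adjust to your actual layout
seq_dir = 'dataset/train/sequence_001'

for name in sorted(os.listdir(seq_dir))[:3]:
    frame = np.load(os.path.join(seq_dir, name))
    # Expect shape (500, 500); note the value range to decide on normalization
    print(name, frame.shape, frame.dtype, frame.min(), frame.max())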
2.1 Implementing the Data Preprocessing Class
We need a preprocessing class that turns the NPY files into tensors the model can consume:
python
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import cv2
class NPYDataset(Dataset):
def __init__(self, data_root, mode='train', input_frames=10, output_frames=10,
future_frames=10, transform=None, preprocess=True):
"""
初始化NPY数据集
参数:
data_root: 数据根目录
mode: 模式 ('train', 'valid', 'test')
input_frames: 输入帧数
output_frames: 输出帧数
future_frames: 未来帧数 (预测帧数)
transform: 数据转换函数
preprocess: 是否进行预处理
"""
self.data_root = os.path.join(data_root, mode)
self.mode = mode
self.input_frames = input_frames
self.output_frames = output_frames
self.future_frames = future_frames
self.transform = transform
self.preprocess = preprocess
        # Collect every sequence that has enough frames
self.sequences = []
for seq_name in os.listdir(self.data_root):
seq_path = os.path.join(self.data_root, seq_name)
if os.path.isdir(seq_path):
frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])
if len(frames) >= input_frames + future_frames:
self.sequences.append((seq_path, frames))
        # Per-split data standardizer (fitted below when preprocessing is enabled)
        self.scaler = None
        if preprocess:
            self._init_scaler()
    def _init_scaler(self):
        """Fit a StandardScaler on a subsample of this split's frames."""
        print(f"Initializing scaler for {self.mode} mode...")
        all_data = []
        for seq_path, frames in self.sequences:
            for frame_name in frames[:min(100, len(frames))]:  # use at most the first 100 frames per sequence
frame_path = os.path.join(seq_path, frame_name)
data = np.load(frame_path)
all_data.append(data.flatten())
all_data = np.concatenate(all_data).reshape(-1, 1)
self.scaler = StandardScaler()
self.scaler.fit(all_data)
print("Scaler initialized.")
def _preprocess_data(self, data):
"""预处理数据"""
if self.preprocess and self.scaler is not None:
original_shape = data.shape
data = data.flatten().reshape(-1, 1)
data = self.scaler.transform(data)
data = data.reshape(original_shape)
return data
def _postprocess_data(self, data):
"""后处理数据"""
if self.preprocess and self.scaler is not None:
original_shape = data.shape
data = data.flatten().reshape(-1, 1)
data = self.scaler.inverse_transform(data)
data = data.reshape(original_shape)
return data
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
seq_path, frames = self.sequences[idx]
        # Randomly pick a start frame during training; always start at 0 otherwise
        total_frames = len(frames)
        max_start = total_frames - self.input_frames - self.future_frames
        start_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0
        # Load the input frames
input_frames = []
for i in range(start_idx, start_idx + self.input_frames):
frame_path = os.path.join(seq_path, frames[i])
frame_data = np.load(frame_path)
frame_data = self._preprocess_data(frame_data)
input_frames.append(frame_data)
        # Load the target (future) frames
target_frames = []
for i in range(start_idx + self.input_frames, start_idx + self.input_frames + self.future_frames):
frame_path = os.path.join(seq_path, frames[i])
frame_data = np.load(frame_path)
frame_data = self._preprocess_data(frame_data)
target_frames.append(frame_data)
        # Stack into arrays of shape [T, H, W]
        input_seq = np.stack(input_frames, axis=0)
        target_seq = np.stack(target_frames, axis=0)
        # Add the channel dimension
        input_seq = np.expand_dims(input_seq, axis=1)  # [T, 1, H, W]
        target_seq = np.expand_dims(target_seq, axis=1)  # [T, 1, H, W]
        # Convert to tensors
input_seq = torch.FloatTensor(input_seq)
target_seq = torch.FloatTensor(target_seq)
        if self.transform:
            # Apply the SAME random transform to the whole sequence so that the
            # spatial augmentation stays consistent between inputs and targets
            full_seq = torch.cat([input_seq, target_seq], dim=0)
            full_seq = self.transform(full_seq)
            input_seq = full_seq[:self.input_frames]
            target_seq = full_seq[self.input_frames:]
return input_seq, target_seq
# Data augmentation transforms
class RandomRotate:
    def __init__(self, angles=(0, 90, 180, 270)):
        self.angles = angles
    def __call__(self, x):
        angle = np.random.choice(self.angles)
        if angle == 0:
            return x
        # Map the sampled angle to the corresponding OpenCV rotation code
        rot_code = {90: cv2.ROTATE_90_CLOCKWISE,
                    180: cv2.ROTATE_180,
                    270: cv2.ROTATE_90_COUNTERCLOCKWISE}[angle]
        # Rotate every frame in the sequence by the same angle
        rotated = []
        for i in range(x.shape[0]):
            frame = x[i].numpy()
            # For [C, H, W] frames, rotate each channel separately
            if len(frame.shape) == 3:
                frame_rotated = np.stack([cv2.rotate(frame[c], rot_code)
                                          for c in range(frame.shape[0])], axis=0)
            else:
                frame_rotated = cv2.rotate(frame, rot_code)
            rotated.append(frame_rotated)
        return torch.FloatTensor(np.stack(rotated, axis=0))
class RandomFlip:
def __init__(self, p=0.5):
self.p = p
def __call__(self, x):
if np.random.random() < self.p:
            # Horizontal flip
return x.flip(-1)
return x
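With the dataset and transforms defined above, it is worth confirming the tensor shapes one batch produces before touching the model. A minimal usage sketch (it reuses the imports and classes from the block above; the 10/10 frame split matches the defaults used later):
python
# Build a small loader and inspect one batch
dataset = NPYDataset(data_root='./dataset', mode='train',
                     input_frames=10, output_frames=10, future_frames=10,
                     preprocess=True)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

inputs, targets = next(iter(loader))
# Expected shapes: inputs [2, 10, 1, 500, 500], targets [2, 10, 1, 500, 500]
print(inputs.shape, targets.shape)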
3. PredRNNv2 Model Principles and Architecture
PredRNNv2 is an improved recurrent network designed for video prediction. It introduces a spatiotemporal memory (STM) that is shared across layers to capture spatiotemporal dynamics better than a plain ConvLSTM.
3.1 Core Components
python
import torch
import torch.nn as nn
from einops import rearrange
class SpatioTemporalLSTMCell(nn.Module):
def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
super(SpatioTemporalLSTMCell, self).__init__()
self.num_hidden = num_hidden
self.padding = filter_size // 2
self._forget_bias = 1.0
        # Gate convolutions on the input, hidden state, memory, and output path
self.conv_x = nn.Sequential(
nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
stride=stride, padding=self.padding, bias=False),
nn.LayerNorm([num_hidden * 7, height, width])
)
self.conv_h = nn.Sequential(
nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
stride=stride, padding=self.padding, bias=False),
nn.LayerNorm([num_hidden * 4, height, width])
)
self.conv_m = nn.Sequential(
nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
stride=stride, padding=self.padding, bias=False),
nn.LayerNorm([num_hidden * 3, height, width])
)
self.conv_o = nn.Sequential(
nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
stride=stride, padding=self.padding, bias=False),
nn.LayerNorm([num_hidden, height, width])
)
self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
stride=1, padding=0, bias=False)
def forward(self, x_t, h_t, c_t, m_t):
        # Project input, hidden state, and memory into gate pre-activations
x_concat = self.conv_x(x_t)
h_concat = self.conv_h(h_t)
m_concat = self.conv_m(m_t)
i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)
i_t = torch.sigmoid(i_x + i_h)
f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
g_t = torch.tanh(g_x + g_h)
c_new = f_t * c_t + i_t * g_t
i_t_prime = torch.sigmoid(i_x_prime + i_m)
f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
g_t_prime = torch.tanh(g_x_prime + g_m)
m_new = f_t_prime * m_t + i_t_prime * g_t_prime
mem = torch.cat((c_new, m_new), 1)
o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
h_new = o_t * torch.tanh(self.conv_last(mem))
return h_new, c_new, m_new
class PredRNNv2(nn.Module):
def __init__(self, configs):
super(PredRNNv2, self).__init__()
self.configs = configs
self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
self.num_layers = len(configs.num_hidden)
self.num_hidden = configs.num_hidden
self.device = configs.device
        # Build the stack of ST-LSTM cells
cell_list = []
height = configs.img_height // configs.patch_size
width = configs.img_width // configs.patch_size
for i in range(self.num_layers):
in_channel = self.frame_channel if i == 0 else self.num_hidden[i-1]
cell_list.append(
SpatioTemporalLSTMCell(
in_channel, self.num_hidden[i], height, width,
configs.filter_size, configs.stride, configs.layer_norm
)
)
self.cell_list = nn.ModuleList(cell_list)
        # Project the top hidden state back to frame channels
self.conv_last = nn.Conv2d(
self.num_hidden[self.num_layers-1], self.frame_channel,
kernel_size=1, stride=1, padding=0, bias=False
)
def forward(self, frames_tensor, mask_true):
# frames_tensor: [batch, length, channel, height, width]
batch = frames_tensor.shape[0]
height = frames_tensor.shape[3]
width = frames_tensor.shape[4]
        # Initialize the hidden and cell states of every layer with zeros
        next_frames = []
        h_t = []
        c_t = []
        for i in range(self.num_layers):
            zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)
            h_t.append(zeros)
            c_t.append(zeros)
        # Spatiotemporal memory shared by all layers (the PredRNN zigzag memory flow)
        memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)
        # Total rollout length: total_length already includes the input frames
        seq_length = self.configs.total_length
        for t in range(seq_length - 1):
            # Choose the frame fed to the network at step t (reverse scheduled
            # sampling from the original PredRNNv2 code is omitted here for brevity)
            if t < self.configs.input_length:
                # Warm-up phase: always feed the ground-truth input frames
                net = frames_tensor[:, t]
            elif mask_true is not None:
                # Scheduled sampling: blend ground truth and the last prediction
                mask = mask_true[:, t - self.configs.input_length]
                net = mask * frames_tensor[:, t] + (1 - mask) * x_gen
            else:
                # Pure autoregressive rollout from the model's own predictions
                net = x_gen
            # The first layer consumes the current frame; the shared memory then
            # flows upward through the remaining layers
            h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)
            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](h_t[i-1], h_t[i], c_t[i], memory)
            # Generate the prediction for frame t+1 from the top layer
            x_gen = self.conv_last(h_t[self.num_layers - 1])
            # Only collect future-horizon predictions so the output shape matches
            # the target sequence [batch, future_frames, channel, height, width]
            if t >= self.configs.input_length - 1:
                next_frames.append(x_gen)
# [length, batch, channel, height, width] -> [batch, length, channel, height, width]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)
return next_frames
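To confirm the tensor plumbing before training on full-resolution data, a small forward pass on a downscaled toy configuration is useful. This is only a sketch: `SimpleNamespace` stands in for the argparse config defined in Section 8, and the sizes are deliberately tiny:
python
from types import SimpleNamespace
import torch

# Toy configuration: 64x64 frames, 2 layers, 4 input frames, 4 predicted frames
cfg = SimpleNamespace(
    img_height=64, img_width=64, img_channel=1, patch_size=1,
    num_hidden=[32, 32], filter_size=5, stride=1, layer_norm=True,
    input_length=4, total_length=8, reverse_scheduled_sampling=0,
    device='cpu',
)
model = PredRNNv2(cfg)
frames = torch.randn(2, cfg.input_length, 1, 64, 64)  # only the observed frames
preds = model(frames, mask_true=None)                 # pure autoregressive rollout
print(preds.shape)  # expected: torch.Size([2, 4, 1, 64, 64])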
4. Data Loader Implementation
Next, we wrap the dataset in train, validation, and test data loaders:
python
def create_data_loaders(configs):
"""创建训练、验证和测试数据加载器"""
    # Training-time data augmentation (the transforms above are plain callables,
    # so compose them with torchvision's Compose rather than nn.Sequential)
    if configs.data_augmentation:
        from torchvision import transforms
        train_transform = transforms.Compose([
            RandomRotate(),
            RandomFlip()
        ])
    else:
        train_transform = None
    # Build the datasets
train_dataset = NPYDataset(
data_root=configs.data_root,
mode='train',
input_frames=configs.input_length,
output_frames=configs.total_length - configs.input_length,
future_frames=configs.total_length - configs.input_length,
transform=train_transform,
preprocess=configs.preprocess_data
)
valid_dataset = NPYDataset(
data_root=configs.data_root,
mode='valid',
input_frames=configs.input_length,
output_frames=configs.total_length - configs.input_length,
future_frames=configs.total_length - configs.input_length,
transform=None,
preprocess=configs.preprocess_data
)
test_dataset = NPYDataset(
data_root=configs.data_root,
mode='test',
input_frames=configs.input_length,
output_frames=configs.total_length - configs.input_length,
future_frames=configs.total_length - configs.input_length,
transform=None,
preprocess=configs.preprocess_data
)
    # Build the data loaders
train_loader = DataLoader(
train_dataset,
batch_size=configs.batch_size,
shuffle=True,
num_workers=configs.num_workers,
pin_memory=True
)
valid_loader = DataLoader(
valid_dataset,
batch_size=configs.batch_size,
shuffle=False,
num_workers=configs.num_workers,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=configs.batch_size,
shuffle=False,
num_workers=configs.num_workers,
pin_memory=True
)
return train_loader, valid_loader, test_loader
5. Model Training Pipeline
Now we implement the full training pipeline, including the loss function, optimizer, and learning-rate scheduler:
python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdm
class Trainer:
def __init__(self, configs, model, train_loader, valid_loader, test_loader):
self.configs = configs
self.model = model
self.train_loader = train_loader
self.valid_loader = valid_loader
self.test_loader = test_loader
self.device = configs.device
        # Loss function
self.criterion = nn.MSELoss()
        # Optimizer
self.optimizer = optim.Adam(
model.parameters(),
lr=configs.lr,
weight_decay=configs.weight_decay
)
        # Learning-rate scheduler (halve the LR when validation loss plateaus)
self.scheduler = ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
patience=5,
verbose=True
)
        # Training history
self.train_losses = []
self.valid_losses = []
self.best_loss = float('inf')
        # Create the checkpoint directory
os.makedirs(configs.save_dir, exist_ok=True)
def train_epoch(self, epoch):
"""训练一个epoch"""
self.model.train()
total_loss = 0
progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
for batch_idx, (inputs, targets) in enumerate(progress_bar):
inputs = inputs.to(self.device)
targets = targets.to(self.device)
            # Forward pass
self.optimizer.zero_grad()
outputs = self.model(inputs, mask_true=None)
            # Compute the loss
loss = self.criterion(outputs, targets)
            # Backward pass and parameter update
loss.backward()
self.optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
avg_loss = total_loss / len(self.train_loader)
self.train_losses.append(avg_loss)
return avg_loss
def validate(self):
"""验证模型"""
self.model.eval()
total_loss = 0
with torch.no_grad():
for inputs, targets in self.valid_loader:
inputs = inputs.to(self.device)
targets = targets.to(self.device)
outputs = self.model(inputs, mask_true=None)
loss = self.criterion(outputs, targets)
total_loss += loss.item()
avg_loss = total_loss / len(self.valid_loader)
self.valid_losses.append(avg_loss)
return avg_loss
def test(self):
"""测试模型"""
self.model.eval()
total_loss = 0
all_outputs = []
all_targets = []
with torch.no_grad():
for inputs, targets in self.test_loader:
inputs = inputs.to(self.device)
targets = targets.to(self.device)
outputs = self.model(inputs, mask_true=None)
loss = self.criterion(outputs, targets)
total_loss += loss.item()
                # Keep outputs and targets for later analysis
all_outputs.append(outputs.cpu().numpy())
all_targets.append(targets.cpu().numpy())
avg_loss = total_loss / len(self.test_loader)
return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)
def save_checkpoint(self, epoch, is_best=False):
"""保存检查点"""
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'train_losses': self.train_losses,
'valid_losses': self.valid_losses,
'best_loss': self.best_loss
}
        # Always overwrite the latest checkpoint
torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
        # Also save a copy as the best checkpoint
if is_best:
torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))
def load_checkpoint(self, checkpoint_path):
"""加载检查点"""
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.train_losses = checkpoint['train_losses']
self.valid_losses = checkpoint['valid_losses']
self.best_loss = checkpoint['best_loss']
return checkpoint['epoch']
def train(self, num_epochs):
"""完整训练过程"""
start_epoch = 0
# 如果存在检查点,加载检查点
if self.configs.resume and os.path.exists(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')):
print("Loading checkpoint...")
start_epoch = self.load_checkpoint(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
print(f"Resumed from epoch {start_epoch}")
for epoch in range(start_epoch, num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.6f}")
            # Validate
            valid_loss = self.validate()
            print(f"Valid Loss: {valid_loss:.6f}")
            # Step the LR scheduler on the validation loss
            self.scheduler.step(valid_loss)
            # Save checkpoints (and track the best validation loss)
            is_best = valid_loss < self.best_loss
            if is_best:
                self.best_loss = valid_loss
            self.save_checkpoint(epoch, is_best)
            # Run the test set every 5 epochs
            if (epoch + 1) % 5 == 0:
                test_loss, _, _ = self.test()
                print(f"Test Loss: {test_loss:.6f}")
        # Final evaluation on the test set
print("\nFinal Testing...")
test_loss, outputs, targets = self.test()
print(f"Final Test Loss: {test_loss:.6f}")
return test_loss, outputs, targets
6. Prediction and Result Visualization
Implement inference and visualization of the results:
python
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
class Predictor:
def __init__(self, configs, model):
self.configs = configs
self.model = model
self.device = configs.device
        if self.model is not None:
            self.model.eval()
def predict(self, input_seq):
"""预测未来帧"""
with torch.no_grad():
input_seq = input_seq.to(self.device)
output_seq = self.model(input_seq, mask_true=None)
return output_seq.cpu()
    def visualize_results(self, inputs, targets, predictions, save_path=None):
        """Visualize input, target, and predicted frames in three rows."""
        # Visualize the first sample in the batch
        inputs = inputs[0].squeeze()            # [T_in, H, W]
        targets = targets[0].squeeze()          # [T_out, H, W]
        predictions = predictions[0].squeeze()  # [T_out, H, W]
        # One column per frame; row 0 = inputs, row 1 = targets, row 2 = predictions
        n_cols = max(inputs.shape[0], targets.shape[0], predictions.shape[0])
        fig = plt.figure(figsize=(2 * n_cols, 6))
        grid = ImageGrid(fig, 111, nrows_ncols=(3, n_cols), axes_pad=0.3)
        # Input frames (row 0)
        for i in range(inputs.shape[0]):
            ax = grid[i]
            ax.imshow(inputs[i], cmap='viridis')
            ax.set_title(f'Input {i+1}')
            ax.axis('off')
        # Target frames (row 1; ImageGrid is indexed row-major)
        for i in range(targets.shape[0]):
            ax = grid[n_cols + i]
            ax.imshow(targets[i], cmap='viridis')
            ax.set_title(f'Target {i+1}')
            ax.axis('off')
        # Predicted frames (row 2)
        for i in range(predictions.shape[0]):
            ax = grid[2 * n_cols + i]
            ax.imshow(predictions[i], cmap='viridis')
            ax.set_title(f'Pred {i+1}')
            ax.axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
def save_predictions(self, predictions, save_dir):
"""保存预测结果为NPY文件"""
os.makedirs(save_dir, exist_ok=True)
for i, pred_seq in enumerate(predictions):
for j, frame in enumerate(pred_seq):
frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')
np.save(frame_path, frame.squeeze())
    def evaluate_metrics(self, targets, predictions):
        """Compute MSE, MAE, RMSE, PSNR, and SSIM between targets and predictions."""
        from sklearn.metrics import mean_squared_error, mean_absolute_error
        # Flatten for the pixel-wise metrics
        targets_flat = targets.flatten()
        predictions_flat = predictions.flatten()
        mse = mean_squared_error(targets_flat, predictions_flat)
        mae = mean_absolute_error(targets_flat, predictions_flat)
        rmse = np.sqrt(mse)
        # PSNR based on the dynamic range of the targets
        data_range = float(targets_flat.max() - targets_flat.min())
        psnr = 20 * np.log10(data_range / rmse) if rmse > 0 else float('inf')
        # SSIM computed per frame and averaged (requires scikit-image)
        try:
            from skimage.metrics import structural_similarity as ssim_func
            t = targets.reshape(-1, targets.shape[-2], targets.shape[-1])
            p = predictions.reshape(-1, predictions.shape[-2], predictions.shape[-1])
            ssim = float(np.mean([
                ssim_func(t[k], p[k], data_range=data_range) for k in range(t.shape[0])
            ]))
        except ImportError:
            ssim = 0
            print("SSIM calculation requires scikit-image. Install with: pip install scikit-image")
return {
'MSE': mse,
'MAE': mae,
'RMSE': rmse,
'PSNR': psnr,
'SSIM': ssim
}
7. Model Evaluation and Optimization
Implement model evaluation and hyperparameter optimization:
python
def hyperparameter_optimization(configs):
"""超参数优化"""
import nni
# 获取NNI超参数
optimized_params = nni.get_next_parameter()
configs.lr = optimized_params.get('lr', configs.lr)
configs.batch_size = optimized_params.get('batch_size', configs.batch_size)
configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)
# 创建模型和数据加载器
model = PredRNNv2(configs).to(configs.device)
train_loader, valid_loader, test_loader = create_data_loaders(configs)
# 训练模型
trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
test_loss, _, _ = trainer.train(configs.epoch)
# 报告最终结果
nni.report_final_result(test_loss)
return test_loss
def analyze_results(configs, outputs, targets):
    """Print evaluation metrics and plot predictions against targets."""
    # Predictor is only used here for its metric helpers, so no model is needed
    predictor = Predictor(configs, None)
    metrics = predictor.evaluate_metrics(targets, outputs)
    print("Evaluation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    # Plot the per-frame mean values so both curves have the same length
    pred_curve = outputs.mean(axis=(2, 3, 4)).flatten()
    target_curve = targets.mean(axis=(2, 3, 4)).flatten()
    plt.figure(figsize=(10, 6))
    plt.plot(pred_curve, label='Predictions', alpha=0.7)
    plt.plot(target_curve, label='Targets', alpha=0.7)
    plt.xlabel('Frame Index')
    plt.ylabel('Mean Value')
    plt.title('Predictions vs Targets (per-frame mean)')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)
    plt.show()
return metrics
8. Complete Code Implementation
Finally, we tie all the components together in a single entry-point script (it assumes the classes above have been saved to models.py, data_loader.py, trainer.py, predictor.py, and utils.py):
python
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_results
def str2bool(v):
    # argparse's type=bool treats any non-empty string as True, so parse explicitly
    return str(v).lower() in ('true', '1', 'yes')
def parse_args():
    parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')
    # Data parameters
    parser.add_argument('--data_root', type=str, default='./dataset', help='dataset root directory')
    parser.add_argument('--input_length', type=int, default=10, help='number of input frames')
    parser.add_argument('--total_length', type=int, default=20, help='total frames (input + predicted)')
    parser.add_argument('--img_width', type=int, default=500, help='image width')
    parser.add_argument('--img_height', type=int, default=500, help='image height')
    parser.add_argument('--img_channel', type=int, default=1, help='number of image channels')
    parser.add_argument('--preprocess_data', type=str2bool, default=True, help='whether to standardize the data')
    parser.add_argument('--data_augmentation', type=str2bool, default=True, help='whether to use data augmentation')
    # Model parameters
    parser.add_argument('--num_hidden', type=int, nargs='+', default=[64, 64, 64, 64], help='hidden units per layer')
    parser.add_argument('--filter_size', type=int, default=5, help='convolution kernel size')
    parser.add_argument('--stride', type=int, default=1, help='convolution stride')
    parser.add_argument('--patch_size', type=int, default=1, help='patch size for space-to-depth folding')
    parser.add_argument('--layer_norm', type=str2bool, default=True, help='whether to use layer normalization')
    parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='reverse scheduled sampling flag')
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
    parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')
    parser.add_argument('--num_workers', type=int, default=4, help='number of data-loading workers')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='checkpoint directory')
    parser.add_argument('--resume', type=str2bool, default=False, help='whether to resume training')
    # Other parameters
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='run mode')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint path')
    return parser.parse_args()
def main():
    # Parse arguments
    configs = parse_args()
    # Create the checkpoint directory
    os.makedirs(configs.save_dir, exist_ok=True)
    # Build the model
    model = PredRNNv2(configs).to(configs.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    if configs.mode == 'train':
        # Build the data loaders
        train_loader, valid_loader, test_loader = create_data_loaders(configs)
        # Create the trainer and start training
        trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
        test_loss, outputs, targets = trainer.train(configs.epoch)
        # Analyze the results
        analyze_results(configs, outputs, targets)
    elif configs.mode == 'test':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path, map_location=configs.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        # Build the data loaders
        _, _, test_loader = create_data_loaders(configs)
        # Evaluate on the test set
        trainer = Trainer(configs, model, None, None, test_loader)
        test_loss, outputs, targets = trainer.test()
        print(f"Test Loss: {test_loss:.6f}")
        # Analyze the results
        metrics = analyze_results(configs, outputs, targets)
        # Save raw outputs and targets for later inspection
        np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)
        np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)
    elif configs.mode == 'predict':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path, map_location=configs.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        # Create the predictor
        predictor = Predictor(configs, model)
        # Load the data to predict on
        # (this assumes a separate 'predict' split laid out like the other splits)
        predict_dataset = NPYDataset(
            data_root=configs.data_root,
            mode='predict',
            input_frames=configs.input_length,
            output_frames=configs.total_length - configs.input_length,
            future_frames=configs.total_length - configs.input_length,
            transform=None,
            preprocess=configs.preprocess_data
        )
predict_loader = DataLoader(
predict_dataset,
batch_size=configs.batch_size,
shuffle=False,
num_workers=configs.num_workers,
pin_memory=True
)
all_predictions = []
all_inputs = []
with torch.no_grad():
for inputs, _ in predict_loader:
inputs = inputs.to(configs.device)
predictions = predictor.predict(inputs)
all_predictions.append(predictions.numpy())
all_inputs.append(inputs.cpu().numpy())
all_predictions = np.concatenate(all_predictions, axis=0)
all_inputs = np.concatenate(all_inputs, axis=0)
        # Save the prediction results as NPY files
output_dir = os.path.join(configs.save_dir, 'predictions')
os.makedirs(output_dir, exist_ok=True)
for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):
            # Save the input sequence
for j, frame in enumerate(input_seq):
frame_path = os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy')
np.save(frame_path, frame.squeeze())
            # Save the predicted sequence
for j, frame in enumerate(pred_seq):
frame_path = os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy')
np.save(frame_path, frame.squeeze())
print(f"Predictions saved to {output_dir}")
        # Visualize one sample (there is no ground truth in predict mode, so the
        # predictions fill both the target and prediction rows)
if len(all_inputs) > 0:
sample_idx = 0
predictor.visualize_results(
all_inputs[sample_idx:sample_idx+1],
all_predictions[sample_idx:sample_idx+1],
all_predictions[sample_idx:sample_idx+1],
save_path=os.path.join(output_dir, 'sample_prediction.png')
)
if __name__ == '__main__':
main()
9. Common Problems and Solutions
9.1 Out-of-Memory Issues
When working with large 500×500 frames, you may run out of GPU memory. Possible remedies:
- Data chunking: split or fold large frames into smaller patches before feeding the network (see the patch-folding sketch at the end of this subsection)
- Smaller batch size: reduce the number of samples processed per step
- Mixed-precision training: use half-precision floats to cut memory usage
python
# Example: training one epoch with automatic mixed precision
from torch.cuda.amp import autocast, GradScaler
def train_epoch_with_amp(self, epoch):
"""使用混合精度训练一个epoch"""
self.model.train()
total_loss = 0
scaler = GradScaler()
progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
for batch_idx, (inputs, targets) in enumerate(progress_bar):
inputs = inputs.to(self.device)
targets = targets.to(self.device)
        # Run the forward pass and loss under autocast
with autocast():
outputs = self.model(inputs, mask_true=None)
loss = self.criterion(outputs, targets)
        # Scale the loss, backpropagate, and step the optimizer via the scaler
self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
total_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
avg_loss = total_loss / len(self.train_loader)
self.train_losses.append(avg_loss)
return avg_loss
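For the data-chunking idea, a common trick in the PredRNN family is space-to-depth patch folding: a [T, C, H, W] sequence is reshaped into [T, C*p*p, H/p, W/p], so the ST-LSTM cells operate on a much smaller spatial grid. This is what the `patch_size` option in the configuration is meant for; the helpers below are a standalone sketch of the idea (the function names are ours, not OpenSTL's):
python
import numpy as np

def reshape_patch(frames, patch_size):
    # Fold a [T, C, H, W] sequence into [T, C*p*p, H/p, W/p] (space-to-depth)
    t, c, h, w = frames.shape
    assert h % patch_size == 0 and w % patch_size == 0
    x = frames.reshape(t, c, h // patch_size, patch_size, w // patch_size, patch_size)
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(t, c * patch_size * patch_size, h // patch_size, w // patch_size)

def reshape_patch_back(patches, patch_size):
    # Inverse operation: unfold [T, C*p*p, H/p, W/p] back to [T, C, H, W]
    t, cpp, hp, wp = patches.shape
    c = cpp // (patch_size * patch_size)
    x = patches.reshape(t, c, patch_size, patch_size, hp, wp)
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(t, c, hp * patch_size, wp * patch_size)

# Example: 500x500 frames with patch_size=5 become 100x100 grids with 25 channels
seq = np.random.rand(10, 1, 500, 500).astype(np.float32)
folded = reshape_patch(seq, patch_size=5)
print(folded.shape)                   # (10, 25, 100, 100)
restored = reshape_patch_back(folded, patch_size=5)
print(np.allclose(restored, seq))     # True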
9.2 Training Instability
PredRNNv2 training can be unstable; the following measures help:
- Gradient clipping: prevent exploding gradients
- Learning-rate scheduling: adjust the learning rate dynamically
- Weight initialization: use a sensible initialization scheme (see the sketch at the end of this subsection)
python
# Example: training one epoch with gradient clipping
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):
"""带梯度裁剪的训练"""
self.model.train()
total_loss = 0
progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
for batch_idx, (inputs, targets) in enumerate(progress_bar):
inputs = inputs.to(self.device)
targets = targets.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(inputs, mask_true=None)
loss = self.criterion(outputs, targets)
loss.backward()
        # Clip the global gradient norm to stabilize training
torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)
self.optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
avg_loss = total_loss / len(self.train_loader)
self.train_losses.append(avg_loss)
return avg_loss
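For weight initialization, one reasonable choice (our assumption, not something OpenSTL prescribes) is Kaiming-normal initialization for the convolutional gates and an identity-like initialization for the LayerNorm layers:
python
import torch.nn as nn

def init_weights(model):
    # Kaiming-normal init for conv layers, identity-like init for LayerNorm
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

# Usage: call init_weights(model) once, before constructing the optimizer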
9.3 Overfitting
If the model does well on the training set but poorly on the validation set, it is likely overfitting:
- Data augmentation: increase data diversity
- Regularization: use Dropout or weight decay
- Early stopping: stop training once the validation loss stops improving
python
# Early stopping implementation
class EarlyStopping:
def __init__(self, patience=10, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss > self.best_loss - self.min_delta:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_loss = val_loss
self.counter = 0
return self.early_stop
# Using early stopping inside the training loop
early_stopping = EarlyStopping(patience=10)
for epoch in range(num_epochs):
    # ... training and validation steps ...
if early_stopping(valid_loss):
print("Early stopping triggered")
break
10. Summary and Outlook
This article described how to reproduce the PredRNNv2 model from OpenSTL and train and predict on a custom NPY dataset, covering the full pipeline from environment setup and data preprocessing through model construction, training, and evaluation.
10.1 Key Results
- Complete data pipeline: loading, preprocessing, and augmentation for NPY-format data
- PredRNNv2 reproduction: implementation of the model's core components and full architecture
- Training framework: a complete train/validation/test loop with loss function, optimizer, and learning-rate scheduling
- Prediction and visualization: inference utilities and plots for analyzing model performance
- Troubleshooting: remedies for common issues (out of memory, unstable training, overfitting)
10.2 Future Directions
- Model upgrades: try more recent video prediction models such as SimVP or PhyDNet
- Multimodal fusion: incorporate other data sources (e.g. meteorological or geographic information) to improve accuracy
- Real-time prediction: optimize inference speed for real-time use
- Uncertainty quantification: estimate the uncertainty of the predictions
- Deployment: move the model into production and support large-scale data processing
With the guidance and code in this article, you should be able to reproduce PredRNNv2 and train and predict on your own dataset. We hope this write-up provides a useful reference for research on and applications of video prediction.