OpenSTL PredRNNv2: Model Reproduction and Training on a Custom Dataset

Overview

This article walks through reproducing the PredRNNv2 model from OpenSTL and training and predicting with a custom dataset stored as NPY files. Starting from environment setup, it covers data preprocessing, model construction, the training loop, and prediction, with the goal of feeding in a sequence of consecutive 500×500 frames and producing the corresponding number of predicted frames.

Table of Contents

  1. Environment Setup and Dependencies
  2. Dataset Preparation and Preprocessing
  3. PredRNNv2: Principles and Architecture
  4. Data Loader Implementation
  5. Model Training Pipeline
  6. Prediction and Visualization
  7. Model Evaluation and Optimization
  8. Complete Code
  9. Common Problems and Solutions
  10. Summary and Outlook

1. Environment Setup and Dependencies

First, create a suitable Python environment and install the required packages.

bash
# Create the conda environment
conda create -n openstl python=3.8
conda activate openstl

# Install PyTorch (pick the build matching your CUDA version)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# Install the remaining dependencies
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0

Next, clone the OpenSTL repository and install it in editable mode:

bash
git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .
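
Before moving on, it is worth a quick sanity check that PyTorch sees the GPU and that the OpenSTL package is importable (the editable install above is expected to provide a package named openstl). A minimal check:

python
# Minimal environment sanity check
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# Should import without error after `pip install -e .`
import openstl
print("OpenSTL location:", openstl.__file__)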

2. Dataset Preparation and Preprocessing

Our dataset consists of NPY files, each holding a 500×500 frame, and consecutive files within a sequence are consecutive in time. First, the expected directory layout:

dataset/
├── train/
│   ├── sequence_001/
│   │   ├── frame_001.npy
│   │   ├── frame_002.npy
│   │   └── ...
│   ├── sequence_002/
│   └── ...
├── valid/
└── test/
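
Before implementing the dataset class, it helps to verify that the directory layout and the per-frame arrays look as expected. The sketch below assumes the dataset/ root shown above; adjust the paths to match your data:

python
# Sketch: inspect the dataset layout and a few frames
import os
import numpy as np

data_root = "./dataset"                      # assumed root directory
split_dir = os.path.join(data_root, "train")

for seq_name in sorted(os.listdir(split_dir))[:3]:   # look at the first few sequences
    seq_path = os.path.join(split_dir, seq_name)
    frames = sorted(f for f in os.listdir(seq_path) if f.endswith(".npy"))
    first = np.load(os.path.join(seq_path, frames[0]))
    print(f"{seq_name}: {len(frames)} frames, shape={first.shape}, "
          f"dtype={first.dtype}, min={first.min():.3f}, max={first.max():.3f}")
    assert first.shape == (500, 500), "each frame is expected to be 500x500"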

2.1 Implementing the Data Preprocessing Class

We need a dataset class that turns the NPY files into tensors the model can consume:

python
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import cv2

class NPYDataset(Dataset):
    def __init__(self, data_root, mode='train', input_frames=10, output_frames=10, 
                 future_frames=10, transform=None, preprocess=True):
        """
        初始化NPY数据集
        
        参数:
            data_root: 数据根目录
            mode: 模式 ('train', 'valid', 'test')
            input_frames: 输入帧数
            output_frames: 输出帧数
            future_frames: 未来帧数 (预测帧数)
            transform: 数据转换函数
            preprocess: 是否进行预处理
        """
        self.data_root = os.path.join(data_root, mode)
        self.mode = mode
        self.input_frames = input_frames
        self.output_frames = output_frames
        self.future_frames = future_frames
        self.transform = transform
        self.preprocess = preprocess
        
        # Collect all sequences that are long enough
        self.sequences = []
        for seq_name in os.listdir(self.data_root):
            seq_path = os.path.join(self.data_root, seq_name)
            if os.path.isdir(seq_path):
                frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])
                if len(frames) >= input_frames + future_frames:
                    self.sequences.append((seq_path, frames))
        
        # Data standardizer
        self.scaler = None
        if preprocess:
            self._init_scaler()
    
    def _init_scaler(self):
        """初始化数据标准化器"""
        print(f"Initializing scaler for {self.mode} mode...")
        all_data = []
        for seq_path, frames in self.sequences:
            for frame_name in frames[:min(100, len(frames))]:  # use at most the first 100 frames to estimate statistics
                frame_path = os.path.join(seq_path, frame_name)
                data = np.load(frame_path)
                all_data.append(data.flatten())
        
        all_data = np.concatenate(all_data).reshape(-1, 1)
        self.scaler = StandardScaler()
        self.scaler.fit(all_data)
        print("Scaler initialized.")
    
    def _preprocess_data(self, data):
        """预处理数据"""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = data.flatten().reshape(-1, 1)
            data = self.scaler.transform(data)
            data = data.reshape(original_shape)
        return data
    
    def _postprocess_data(self, data):
        """后处理数据"""
        if self.preprocess and self.scaler is not None:
            original_shape = data.shape
            data = data.flatten().reshape(-1, 1)
            data = self.scaler.inverse_transform(data)
            data = data.reshape(original_shape)
        return data
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        seq_path, frames = self.sequences[idx]
        
        # Randomly pick a start frame during training; otherwise start at frame 0
        total_frames = len(frames)
        max_start = total_frames - self.input_frames - self.future_frames
        start_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0
        
        # Load the input frames
        input_frames = []
        for i in range(start_idx, start_idx + self.input_frames):
            frame_path = os.path.join(seq_path, frames[i])
            frame_data = np.load(frame_path)
            frame_data = self._preprocess_data(frame_data)
            input_frames.append(frame_data)
        
        # Load the target (future) frames
        target_frames = []
        for i in range(start_idx + self.input_frames, start_idx + self.input_frames + self.future_frames):
            frame_path = os.path.join(seq_path, frames[i])
            frame_data = np.load(frame_path)
            frame_data = self._preprocess_data(frame_data)
            target_frames.append(frame_data)
        
        # Stack into numpy arrays
        input_seq = np.stack(input_frames, axis=0)
        target_seq = np.stack(target_frames, axis=0)
        
        # Add a channel dimension
        input_seq = np.expand_dims(input_seq, axis=1)  # [T, 1, H, W]
        target_seq = np.expand_dims(target_seq, axis=1)  # [T, 1, H, W]
        
        # Convert to tensors
        input_seq = torch.FloatTensor(input_seq)
        target_seq = torch.FloatTensor(target_seq)
        
        if self.transform:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)
        
        return input_seq, target_seq

# Data augmentation transforms
class RandomRotate:
    def __init__(self, angles=[0, 90, 180, 270]):
        self.angles = angles
    
    def __call__(self, x):
        angle = int(np.random.choice(self.angles))
        if angle == 0:
            return x
        # Map the sampled angle to the corresponding OpenCV rotation code
        rot_code = {90: cv2.ROTATE_90_CLOCKWISE,
                    180: cv2.ROTATE_180,
                    270: cv2.ROTATE_90_COUNTERCLOCKWISE}[angle]
        # Rotate every frame by the sampled angle
        rotated = []
        for i in range(x.shape[0]):
            frame = x[i].numpy()
            # For 3D data, rotate each channel separately
            if len(frame.shape) == 3:
                frame_rotated = np.stack([cv2.rotate(frame[c], rot_code)
                                         for c in range(frame.shape[0])], axis=0)
            else:
                frame_rotated = cv2.rotate(frame, rot_code)
            rotated.append(frame_rotated)
        return torch.FloatTensor(np.stack(rotated, axis=0))

class RandomFlip:
    def __init__(self, p=0.5):
        self.p = p
    
    def __call__(self, x):
        if np.random.random() < self.p:
            # Horizontal flip
            return x.flip(-1)
        return x

3. PredRNNv2: Principles and Architecture

PredRNNv2 is an improved recurrent architecture designed for video prediction. It introduces a spatiotemporal memory (ST memory) that flows across both layers and time steps to better capture spatiotemporal dynamics.
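
Note that the PredRNNv2 class below computes frame_channel = patch_size² × img_channel, which assumes frames are rearranged into non-overlapping patches before entering the network, as in the original PredRNN code base. With patch_size=1 (the default used later) no rearrangement is needed; otherwise a helper along the following lines can be used. This is a sketch rather than the official OpenSTL utility; with 500×500 frames, patch_size values such as 2, 4 or 5 divide the resolution evenly:

python
# Sketch: fold p x p patches into the channel dimension and back
import torch

def reshape_patch(frames, patch_size):
    """[B, T, C, H, W] -> [B, T, C*p*p, H/p, W/p]"""
    B, T, C, H, W = frames.shape
    assert H % patch_size == 0 and W % patch_size == 0
    x = frames.reshape(B, T, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 1, 2, 4, 6, 3, 5)
    return x.reshape(B, T, C * patch_size * patch_size, H // patch_size, W // patch_size)

def reshape_patch_back(patches, patch_size):
    """Inverse of reshape_patch: [B, T, C*p*p, H/p, W/p] -> [B, T, C, H, W]"""
    B, T, CP, Hp, Wp = patches.shape
    C = CP // (patch_size * patch_size)
    x = patches.reshape(B, T, C, patch_size, patch_size, Hp, Wp)
    x = x.permute(0, 1, 2, 5, 3, 6, 4)
    return x.reshape(B, T, C, Hp * patch_size, Wp * patch_size)

# Example: 500x500 frames with patch_size=5 become 25-channel 100x100 maps
x = torch.randn(2, 10, 1, 500, 500)
p = reshape_patch(x, 5)                      # [2, 10, 25, 100, 100]
assert torch.allclose(reshape_patch_back(p, 5), x)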

3.1 Core Components

python
import torch
import torch.nn as nn
from einops import rearrange

class SpatioTemporalLSTMCell(nn.Module):
    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):
        super(SpatioTemporalLSTMCell, self).__init__()
        
        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0
        
        # Convolutional gate layers (note: LayerNorm is always applied here; the layer_norm flag is not checked)
        self.conv_x = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 7, height, width])
        )
        
        self.conv_h = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 4, height, width])
        )
        
        self.conv_m = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden * 3, height, width])
        )
        
        self.conv_o = nn.Sequential(
            nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,
                      stride=stride, padding=self.padding, bias=False),
            nn.LayerNorm([num_hidden, height, width])
        )
        
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,
                                   stride=1, padding=0, bias=False)
    
    def forward(self, x_t, h_t, c_t, m_t):
        # Compute the gate activations
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)
        
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)
        
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)
        
        c_new = f_t * c_t + i_t * g_t
        
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)
        
        m_new = f_t_prime * m_t + i_t_prime * g_t_prime
        
        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))
        
        return h_new, c_new, m_new

class PredRNNv2(nn.Module):
    def __init__(self, configs):
        super(PredRNNv2, self).__init__()
        
        self.configs = configs
        self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
        self.num_layers = len(configs.num_hidden)
        self.num_hidden = configs.num_hidden
        self.device = configs.device
        
        # Build the stack of ST-LSTM cells
        cell_list = []
        
        height = configs.img_height // configs.patch_size
        width = configs.img_width // configs.patch_size
        
        for i in range(self.num_layers):
            in_channel = self.frame_channel if i == 0 else self.num_hidden[i-1]
            cell_list.append(
                SpatioTemporalLSTMCell(
                    in_channel, self.num_hidden[i], height, width,
                    configs.filter_size, configs.stride, configs.layer_norm
                )
            )
        self.cell_list = nn.ModuleList(cell_list)
        
        # Output layer (1x1 conv back to frame channels)
        self.conv_last = nn.Conv2d(
            self.num_hidden[self.num_layers-1], self.frame_channel,
            kernel_size=1, stride=1, padding=0, bias=False
        )
    
    def forward(self, frames_tensor, mask_true=None):
        # frames_tensor: [batch, length, channel, height, width]
        batch = frames_tensor.shape[0]
        height = frames_tensor.shape[3]
        width = frames_tensor.shape[4]
        
        # Initialize the hidden states and cell states
        next_frames = []
        h_t = []
        c_t = []
        
        for i in range(self.num_layers):
            zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)
            h_t.append(zeros)
            c_t.append(zeros)
        
        # The spatiotemporal memory is shared and flows across layers and time steps
        memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)
        
        # Total sequence length (input frames + predicted frames)
        seq_length = self.configs.total_length
        
        for t in range(seq_length - 1):
            if mask_true is None:
                # Inference / plain teacher forcing: feed ground truth for the input
                # segment, then feed back the model's own predictions
                net = frames_tensor[:, t] if t < self.configs.input_length else x_gen
            elif self.configs.reverse_scheduled_sampling == 1:
                # Reverse scheduled sampling: mix ground truth and prediction from t = 1
                if t == 0:
                    net = frames_tensor[:, t]
                else:
                    net = mask_true[:, t - 1] * frames_tensor[:, t] + \
                          (1 - mask_true[:, t - 1]) * x_gen
            else:
                # Regular scheduled sampling on the prediction segment
                if t < self.configs.input_length:
                    net = frames_tensor[:, t]
                else:
                    net = mask_true[:, t - self.configs.input_length] * frames_tensor[:, t] + \
                          (1 - mask_true[:, t - self.configs.input_length]) * x_gen
            
            # First layer consumes the current frame
            h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)
            
            # Higher layers consume the hidden state of the layer below
            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](h_t[i-1], h_t[i], c_t[i], memory)
            
            # Generate the next frame from the top layer
            x_gen = self.conv_last(h_t[self.num_layers-1])
            next_frames.append(x_gen)
        
        # [length, batch, channel, height, width] -> [batch, length, channel, height, width]
        # Note: the memory-decoupling loss of the original PredRNNv2 is omitted in this simplified version.
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)
        
        return next_frames
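
During training, the forward pass above can mix ground truth and the model's own predictions through mask_true (scheduled sampling). The Trainer below simply passes mask_true=None, which falls back to teacher forcing on the input segment; if you want scheduled sampling, a per-batch mask can be drawn roughly as follows. This is a sketch modelled on the original PredRNN training script, and the decay field sampling_changing_rate is an assumed config entry, not one defined elsewhere in this article:

python
# Sketch: build a mask_true tensor for one training iteration
import numpy as np
import torch

def schedule_sampling(eta, configs):
    """Linearly decay eta, the probability of feeding ground truth on the prediction segment."""
    mask_len = configs.total_length - configs.input_length - 1
    frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
    height = configs.img_height // configs.patch_size
    width = configs.img_width // configs.patch_size

    eta = max(0.0, eta - configs.sampling_changing_rate)   # assumed config field
    keep_true = (np.random.rand(configs.batch_size, mask_len) < eta).astype(np.float32)

    # Broadcast the per-frame decision over channels and spatial dimensions
    mask_true = torch.from_numpy(keep_true)[:, :, None, None, None]
    mask_true = mask_true.expand(-1, -1, frame_channel, height, width).contiguous()
    return eta, mask_true

# Usage in a training loop (sketch):
#   eta, mask_true = schedule_sampling(eta, configs)
#   outputs = model(frames, mask_true.to(configs.device))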

4. Data Loader Implementation

Next, we wrap the dataset in PyTorch DataLoaders:

python
def create_data_loaders(configs):
    """Create the train, validation and test data loaders."""
    from torchvision import transforms  # Compose accepts arbitrary callables
    
    # Data augmentation (the transforms are plain callables, so use Compose rather than nn.Sequential)
    if configs.data_augmentation:
        train_transform = transforms.Compose([
            RandomRotate(),
            RandomFlip()
        ])
    else:
        train_transform = None
    
    # Create the datasets
    train_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='train',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=train_transform,
        preprocess=configs.preprocess_data
    )
    
    valid_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='valid',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data
    )
    
    test_dataset = NPYDataset(
        data_root=configs.data_root,
        mode='test',
        input_frames=configs.input_length,
        output_frames=configs.total_length - configs.input_length,
        future_frames=configs.total_length - configs.input_length,
        transform=None,
        preprocess=configs.preprocess_data
    )
    
    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=configs.batch_size,
        shuffle=True,
        num_workers=configs.num_workers,
        pin_memory=True
    )
    
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        num_workers=configs.num_workers,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=configs.batch_size,
        shuffle=False,
        num_workers=configs.num_workers,
        pin_memory=True
    )
    
    return train_loader, valid_loader, test_loader
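
A quick smoke test of the loaders catches shape problems before a long training run. The snippet below assumes a minimal configs namespace containing just the fields used above:

python
# Sketch: check one batch from the training loader
from types import SimpleNamespace

configs = SimpleNamespace(                   # assumed minimal config for this test
    data_root='./dataset', input_length=10, total_length=20,
    batch_size=2, num_workers=0, data_augmentation=False, preprocess_data=False)

train_loader, valid_loader, test_loader = create_data_loaders(configs)
inputs, targets = next(iter(train_loader))
print("inputs:", tuple(inputs.shape))        # expected: (2, 10, 1, 500, 500)
print("targets:", tuple(targets.shape))      # expected: (2, 10, 1, 500, 500)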

5. Model Training Pipeline

We now implement the full training pipeline, including the loss function, optimizer and learning-rate scheduler:

python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdm

class Trainer:
    def __init__(self, configs, model, train_loader, valid_loader, test_loader):
        self.configs = configs
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.device = configs.device
        
        # Loss function
        self.criterion = nn.MSELoss()
        
        # Optimizer
        self.optimizer = optim.Adam(
            model.parameters(),
            lr=configs.lr,
            weight_decay=configs.weight_decay
        )
        
        # Learning-rate scheduler
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )
        
        # Training history
        self.train_losses = []
        self.valid_losses = []
        self.best_loss = float('inf')
        
        # Create the checkpoint directory
        os.makedirs(configs.save_dir, exist_ok=True)
    
    def train_epoch(self, epoch):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            
            # The model expects the full sequence (inputs followed by targets)
            frames = torch.cat([inputs, targets], dim=1)
            
            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(frames, mask_true=None)
            
            # Compute the loss on the predicted future frames only
            loss = self.criterion(outputs[:, -targets.size(1):], targets)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
        
        avg_loss = total_loss / len(self.train_loader)
        self.train_losses.append(avg_loss)
        
        return avg_loss
    
    def validate(self):
        """Evaluate on the validation set."""
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for inputs, targets in self.valid_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                frames = torch.cat([inputs, targets], dim=1)
                outputs = self.model(frames, mask_true=None)
                loss = self.criterion(outputs[:, -targets.size(1):], targets)
                
                total_loss += loss.item()
        
        avg_loss = total_loss / len(self.valid_loader)
        self.valid_losses.append(avg_loss)
        
        return avg_loss
    
    def test(self):
        """Evaluate on the test set and collect predictions."""
        self.model.eval()
        total_loss = 0
        all_outputs = []
        all_targets = []
        
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                frames = torch.cat([inputs, targets], dim=1)
                outputs = self.model(frames, mask_true=None)
                preds = outputs[:, -targets.size(1):]
                loss = self.criterion(preds, targets)
                
                total_loss += loss.item()
                
                # Keep the results for later analysis
                all_outputs.append(preds.cpu().numpy())
                all_targets.append(targets.cpu().numpy())
        
        avg_loss = total_loss / len(self.test_loader)
        
        return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)
    
    def save_checkpoint(self, epoch, is_best=False):
        """保存检查点"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'valid_losses': self.valid_losses,
            'best_loss': self.best_loss
        }
        
        # Save the latest checkpoint
        torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
        
        # If this is the best model so far, also save it as the best checkpoint
        if is_best:
            torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))
    
    def load_checkpoint(self, checkpoint_path):
        """加载检查点"""
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.valid_losses = checkpoint['valid_losses']
        self.best_loss = checkpoint['best_loss']
        
        return checkpoint['epoch'] + 1  # resume from the epoch after the saved one
    
    def train(self, num_epochs):
        """完整训练过程"""
        start_epoch = 0
        
        # Resume from the latest checkpoint if available
        if self.configs.resume and os.path.exists(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')):
            print("Loading checkpoint...")
            start_epoch = self.load_checkpoint(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))
            print(f"Resumed from epoch {start_epoch}")
        
        for epoch in range(start_epoch, num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            
            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.6f}")
            
            # Validate
            valid_loss = self.validate()
            print(f"Valid Loss: {valid_loss:.6f}")
            
            # Update the learning rate
            self.scheduler.step(valid_loss)
            
            # Save a checkpoint
            is_best = valid_loss < self.best_loss
            if is_best:
                self.best_loss = valid_loss
            
            self.save_checkpoint(epoch, is_best)
            
            # Run the test set every 5 epochs
            if (epoch + 1) % 5 == 0:
                test_loss, _, _ = self.test()
                print(f"Test Loss: {test_loss:.6f}")
        
        # Final evaluation
        print("\nFinal Testing...")
        test_loss, outputs, targets = self.test()
        print(f"Final Test Loss: {test_loss:.6f}")
        
        return test_loss, outputs, targets
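
Putting the pieces together, a training run looks roughly like the sketch below. The configs namespace mirrors the argparse options defined in section 8; in a real run you would use parse_args() instead:

python
# Sketch: wire up the model, the loaders and the trainer
from types import SimpleNamespace
import torch

configs = SimpleNamespace(
    data_root='./dataset', input_length=10, total_length=20,
    img_height=500, img_width=500, img_channel=1, patch_size=1,
    num_hidden=[64, 64, 64, 64], filter_size=5, stride=1, layer_norm=True,
    reverse_scheduled_sampling=0, batch_size=2, num_workers=2,
    lr=1e-3, weight_decay=0, epoch=50, save_dir='./checkpoints', resume=False,
    data_augmentation=True, preprocess_data=True,
    device='cuda' if torch.cuda.is_available() else 'cpu')

model = PredRNNv2(configs).to(configs.device)
train_loader, valid_loader, test_loader = create_data_loaders(configs)
trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
test_loss, outputs, targets = trainer.train(configs.epoch)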

6. Prediction and Visualization

Implementing prediction and result visualization:

python
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

class Predictor:
    def __init__(self, configs, model):
        self.configs = configs
        self.model = model
        self.device = configs.device
        if self.model is not None:  # allow metric-only use (e.g. analyze_results) without a model
            self.model.eval()
    
    def predict(self, input_seq):
        """Predict future frames from an input sequence of shape [B, T_in, C, H, W]."""
        with torch.no_grad():
            input_seq = input_seq.to(self.device)
            output_seq = self.model(input_seq, mask_true=None)
            # Keep only the predicted future frames
            future = self.configs.total_length - self.configs.input_length
            return output_seq[:, -future:].cpu()
    
    def visualize_results(self, inputs, targets, predictions, save_path=None):
        """Visualize inputs, targets and predictions for the first sample in the batch."""
        inputs = inputs[0].squeeze()            # [T_in, H, W]
        targets = targets[0].squeeze()          # [T_out, H, W]
        predictions = predictions[0].squeeze()  # [T_out, H, W]
        
        # One row per role; columns cover the longest sequence
        n_cols = max(inputs.shape[0], targets.shape[0], predictions.shape[0])
        fig = plt.figure(figsize=(2 * n_cols, 6))
        grid = ImageGrid(fig, 111, nrows_ncols=(3, n_cols), axes_pad=0.1)
        
        # Row 0: input frames
        for i in range(inputs.shape[0]):
            ax = grid[i]
            ax.imshow(inputs[i], cmap='viridis')
            ax.set_title(f'Input {i+1}')
            ax.axis('off')
        
        # Row 1: target frames
        for i in range(targets.shape[0]):
            ax = grid[n_cols + i]
            ax.imshow(targets[i], cmap='viridis')
            ax.set_title(f'Target {i+1}')
            ax.axis('off')
        
        # Row 2: predicted frames
        for i in range(predictions.shape[0]):
            ax = grid[2 * n_cols + i]
            ax.imshow(predictions[i], cmap='viridis')
            ax.set_title(f'Pred {i+1}')
            ax.axis('off')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
    
    def save_predictions(self, predictions, save_dir):
        """保存预测结果为NPY文件"""
        os.makedirs(save_dir, exist_ok=True)
        
        for i, pred_seq in enumerate(predictions):
            for j, frame in enumerate(pred_seq):
                frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')
                np.save(frame_path, frame.squeeze())
    
    def evaluate_metrics(self, targets, predictions):
        """Compute evaluation metrics for the predictions."""
        from sklearn.metrics import mean_squared_error, mean_absolute_error
        
        # Flatten for the pixel-wise metrics
        targets_flat = targets.flatten()
        predictions_flat = predictions.flatten()
        
        mse = mean_squared_error(targets_flat, predictions_flat)
        mae = mean_absolute_error(targets_flat, predictions_flat)
        rmse = np.sqrt(mse)
        
        # PSNR
        max_val = np.max(targets_flat)
        psnr = 20 * np.log10(max_val / rmse) if rmse > 0 else float('inf')
        
        # SSIM, averaged over all 2D frames (requires scikit-image)
        try:
            from skimage.metrics import structural_similarity as ssim_func
            t = targets.reshape(-1, targets.shape[-2], targets.shape[-1])
            p = predictions.reshape(-1, predictions.shape[-2], predictions.shape[-1])
            ssim = np.mean([ssim_func(t[i], p[i], data_range=max_val) for i in range(t.shape[0])])
        except ImportError:
            ssim = 0
            print("SSIM calculation requires scikit-image. Install with: pip install scikit-image")
        
        return {
            'MSE': mse,
            'MAE': mae,
            'RMSE': rmse,
            'PSNR': psnr,
            'SSIM': ssim
        }
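
Continuing from the training sketch in section 5, the predictor can be exercised on a single test batch as follows (a sketch; it reuses the configs, model and test_loader defined in that sketch):

python
# Sketch: run the predictor on one test batch
import os
import torch

checkpoint = torch.load(os.path.join(configs.save_dir, 'best_checkpoint.pth'),
                        map_location=configs.device)
model.load_state_dict(checkpoint['model_state_dict'])

predictor = Predictor(configs, model)
inputs, targets = next(iter(test_loader))
preds = predictor.predict(inputs)                      # [B, T_out, C, H, W]

predictor.visualize_results(inputs.numpy(), targets.numpy(), preds.numpy(),
                            save_path='sample_prediction.png')
print(predictor.evaluate_metrics(targets.numpy(), preds.numpy()))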

7. Model Evaluation and Optimization

Implementing model evaluation and hyperparameter optimization:

python
def hyperparameter_optimization(configs):
    """超参数优化"""
    import nni
    
    # Fetch the next set of hyperparameters from NNI
    optimized_params = nni.get_next_parameter()
    configs.lr = optimized_params.get('lr', configs.lr)
    configs.batch_size = optimized_params.get('batch_size', configs.batch_size)
    configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)
    
    # Build the model and the data loaders
    model = PredRNNv2(configs).to(configs.device)
    train_loader, valid_loader, test_loader = create_data_loaders(configs)
    
    # Train the model
    trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
    test_loss, _, _ = trainer.train(configs.epoch)
    
    # Report the final result to NNI
    nni.report_final_result(test_loss)
    
    return test_loss

def analyze_results(configs, outputs, targets):
    """Analyze the prediction results."""
    predictor = Predictor(configs, None)
    metrics = predictor.evaluate_metrics(targets, outputs)
    
    print("Evaluation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    
    # Plot per-frame mean values (plotting every pixel would be far too many points)
    pred_means = outputs.reshape(-1, outputs.shape[-2] * outputs.shape[-1]).mean(axis=1)
    target_means = targets.reshape(-1, targets.shape[-2] * targets.shape[-1]).mean(axis=1)
    plt.figure(figsize=(10, 6))
    plt.plot(pred_means, label='Predictions', alpha=0.7)
    plt.plot(target_means, label='Targets', alpha=0.7)
    plt.xlabel('Frame Index')
    plt.ylabel('Mean Value')
    plt.title('Predictions vs Targets (per-frame mean)')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)
    plt.show()
    
    return metrics
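
To actually launch the NNI experiment, a search space covering the tuned fields above has to be provided. Below is a minimal sketch of such a search space, written from Python; the exact experiment configuration and launch command depend on your NNI version:

python
# Sketch: write an NNI search space matching the fields tuned above
import json

search_space = {
    "lr":         {"_type": "loguniform", "_value": [1e-4, 1e-2]},
    "batch_size": {"_type": "choice",     "_value": [2, 4, 8]},
    "num_hidden": {"_type": "choice",     "_value": [[64, 64, 64, 64], [128, 128, 128, 128]]},
}

with open("search_space.json", "w") as f:
    json.dump(search_space, f, indent=2)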

8. Complete Code

Finally, we tie all the components together in a single script:

python
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_results

def parse_args():
    parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')
    
    # Data arguments
    parser.add_argument('--data_root', type=str, default='./dataset', help='dataset root directory')
    parser.add_argument('--input_length', type=int, default=10, help='number of input frames')
    parser.add_argument('--total_length', type=int, default=20, help='total number of frames (input + predicted)')
    parser.add_argument('--img_width', type=int, default=500, help='image width')
    parser.add_argument('--img_height', type=int, default=500, help='image height')
    parser.add_argument('--img_channel', type=int, default=1, help='number of image channels')
    parser.add_argument('--preprocess_data', type=bool, default=True, help='standardize the data (note: argparse type=bool treats any non-empty string as True)')
    parser.add_argument('--data_augmentation', type=bool, default=True, help='use data augmentation')
    
    # Model arguments
    parser.add_argument('--num_hidden', type=int, nargs='+', default=[64, 64, 64, 64], help='hidden units per layer')
    parser.add_argument('--filter_size', type=int, default=5, help='convolution kernel size')
    parser.add_argument('--stride', type=int, default=1, help='convolution stride')
    parser.add_argument('--patch_size', type=int, default=1, help='patch size')
    parser.add_argument('--layer_norm', type=bool, default=True, help='use layer normalization')
    parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='use reverse scheduled sampling (0/1)')
    
    # Training arguments
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
    parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')
    parser.add_argument('--num_workers', type=int, default=4, help='number of data-loading workers')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='device')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='checkpoint directory')
    parser.add_argument('--resume', action='store_true', help='resume training from the latest checkpoint')
    
    # Other arguments
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='run mode')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint path')
    
    return parser.parse_args()

def main():
    # Parse arguments
    configs = parse_args()
    
    # Create the save directory
    os.makedirs(configs.save_dir, exist_ok=True)
    
    # Build the model
    model = PredRNNv2(configs).to(configs.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    if configs.mode == 'train':
        # Create the data loaders
        train_loader, valid_loader, test_loader = create_data_loaders(configs)
        
        # Create the trainer and start training
        trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)
        test_loss, outputs, targets = trainer.train(configs.epoch)
        
        # Analyze the results
        analyze_results(configs, outputs, targets)
        
    elif configs.mode == 'test':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path, map_location=configs.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        
        # Create the data loaders
        _, _, test_loader = create_data_loaders(configs)
        
        # Evaluate the model
        trainer = Trainer(configs, model, None, None, test_loader)
        test_loss, outputs, targets = trainer.test()
        
        print(f"Test Loss: {test_loss:.6f}")
        
        # Analyze the results
        metrics = analyze_results(configs, outputs, targets)
        
        # Save the results
        np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)
        np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)
        
    elif configs.mode == 'predict':
        # Load the checkpoint
        if configs.checkpoint_path:
            checkpoint = torch.load(configs.checkpoint_path, map_location=configs.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded checkpoint from {configs.checkpoint_path}")
        
        # Create the predictor
        predictor = Predictor(configs, model)
        
        # Load the data to predict on
        # (this assumes a separate 'predict' split under the dataset root)
        predict_dataset = NPYDataset(
            data_root=configs.data_root,
            mode='predict',
            input_frames=configs.input_length,
            output_frames=configs.total_length - configs.input_length,
            future_frames=configs.total_length - configs.input_length,
            transform=None,
            preprocess=configs.preprocess_data
        )
        
        predict_loader = DataLoader(
            predict_dataset,
            batch_size=configs.batch_size,
            shuffle=False,
            num_workers=configs.num_workers,
            pin_memory=True
        )
        
        all_predictions = []
        all_inputs = []
        
        with torch.no_grad():
            for inputs, _ in predict_loader:
                inputs = inputs.to(configs.device)
                predictions = predictor.predict(inputs)
                
                all_predictions.append(predictions.numpy())
                all_inputs.append(inputs.cpu().numpy())
        
        all_predictions = np.concatenate(all_predictions, axis=0)
        all_inputs = np.concatenate(all_inputs, axis=0)
        
        # Save the prediction results
        output_dir = os.path.join(configs.save_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        
        for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):
            # Save the input sequence
            for j, frame in enumerate(input_seq):
                frame_path = os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy')
                np.save(frame_path, frame.squeeze())
            
            # Save the predicted sequence
            for j, frame in enumerate(pred_seq):
                frame_path = os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy')
                np.save(frame_path, frame.squeeze())
        
        print(f"Predictions saved to {output_dir}")
        
        # Visualize one sample (no ground truth is available in predict mode,
        # so the prediction is shown in both the target and prediction rows)
        if len(all_inputs) > 0:
            sample_idx = 0
            predictor.visualize_results(
                all_inputs[sample_idx:sample_idx+1],
                all_predictions[sample_idx:sample_idx+1],
                all_predictions[sample_idx:sample_idx+1],
                save_path=os.path.join(output_dir, 'sample_prediction.png')
            )

if __name__ == '__main__':
    main()

9. Common Problems and Solutions

9.1 Running Out of Memory

With large 500×500 frames, you may run out of GPU memory. Possible remedies:

  1. Tile the data: split large frames into smaller crops and process them separately (a sketch follows the mixed-precision example below)
  2. Reduce the batch size: process fewer samples per step
  3. Use mixed-precision training: half-precision floats roughly halve the activation memory
python
# Mixed-precision training example
from torch.cuda.amp import autocast, GradScaler

def train_epoch_with_amp(self, epoch):
    """Train one epoch with automatic mixed precision."""
    self.model.train()
    total_loss = 0
    scaler = GradScaler()
    
    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        frames = torch.cat([inputs, targets], dim=1)
        
        # Run the forward pass under autocast
        self.optimizer.zero_grad()
        with autocast():
            outputs = self.model(frames, mask_true=None)
            loss = self.criterion(outputs[:, -targets.size(1):], targets)
        
        # Scale the loss, backpropagate and step the optimizer
        scaler.scale(loss).backward()
        scaler.step(self.optimizer)
        scaler.update()
        
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    
    return avg_loss
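
For remedy 1 above, a simple way to tile the 500×500 frames is to cut every frame into non-overlapping crops and treat each crop as its own sequence. The crop size below is an assumption; 500 divides evenly into 250 or 100:

python
# Sketch: split a frame sequence into spatial tiles
import torch

def tile_sequence(frames, tile=250):
    """[T, C, H, W] -> list of [T, C, tile, tile] crops covering the full frame."""
    T, C, H, W = frames.shape
    assert H % tile == 0 and W % tile == 0
    crops = []
    for y in range(0, H, tile):
        for x in range(0, W, tile):
            crops.append(frames[:, :, y:y + tile, x:x + tile])
    return crops

seq = torch.randn(10, 1, 500, 500)
crops = tile_sequence(seq, tile=250)         # 4 crops of shape [10, 1, 250, 250]
print(len(crops), crops[0].shape)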

9.2 Unstable Training

PredRNNv2 training can be unstable. Things to try:

  1. Gradient clipping: prevents exploding gradients
  2. Learning-rate scheduling: adjust the learning rate dynamically
  3. Weight initialization: use a suitable initialization scheme
python
# Gradient clipping example
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):
    """Train one epoch with gradient clipping."""
    self.model.train()
    total_loss = 0
    
    progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
    
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        frames = torch.cat([inputs, targets], dim=1)
        
        self.optimizer.zero_grad()
        outputs = self.model(frames, mask_true=None)
        loss = self.criterion(outputs[:, -targets.size(1):], targets)
        loss.backward()
        
        # Clip the gradient norm
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)
        
        self.optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    avg_loss = total_loss / len(self.train_loader)
    self.train_losses.append(avg_loss)
    
    return avg_loss

9.3 Overfitting

If the model does well on the training set but poorly on the validation set, it is likely overfitting:

  1. Data augmentation: increase data diversity
  2. Regularization: use Dropout or weight decay
  3. Early stopping: stop training when the validation loss stops improving
python
# Early-stopping implementation
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
    
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        
        return self.early_stop

# Using early stopping in the training loop
early_stopping = EarlyStopping(patience=10)

for epoch in range(num_epochs):
    # ... train and validate ...
    if early_stopping(valid_loss):
        print("Early stopping triggered")
        break

10. Summary and Outlook

This article walked through reproducing the PredRNNv2 model from OpenSTL and training and predicting with a custom NPY dataset, covering the full pipeline from environment setup and data preprocessing to model construction, training and evaluation.

10.1 Main Results

  1. Complete data pipeline: loading, preprocessing and augmentation for NPY-formatted data
  2. PredRNNv2 reproduction: the core components and overall architecture of the model
  3. Training framework: a full train/validate/test loop with loss function, optimizer and learning-rate scheduling
  4. Prediction and visualization: utilities for generating predictions and inspecting model performance
  5. Troubleshooting: solutions for common issues (out-of-memory errors, unstable training, overfitting)

10.2 Future Work

  1. Model improvements: try more recent video-prediction models such as SimVP and PhyDNet
  2. Multi-modal fusion: combine additional sensor data (e.g. meteorological or geographic information) to improve accuracy
  3. Real-time prediction: optimize inference speed for real-time use
  4. Uncertainty quantification: estimate the uncertainty of the predictions
  5. Deployment: move the model into production and support large-scale data processing

With the guidance and code in this article, readers should be able to reproduce PredRNNv2 and train and predict on their own datasets. We hope this work serves as a useful reference for research and applications in video prediction.
