early-stopping pytorch refs

1)https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

2)https://machinelearningmastery.com/managing-a-pytorch-training-process-with-checkpoints-and-early-stopping/

3)https://pytorch.org/ignite/generated/ignite.handlers.early_stopping.EarlyStopping.html

4)https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

5)https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

复制代码
https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376Step-by-Step Guide in PyTorch
1.Import libraries
import torch
import numpy as np
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Upsample, Concatenate
from torch.optim import Adam
import copy
2. Define the U-Net Architecture

class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UNet, self).__init__()

        # Contracting path
        self.conv1 = Conv2d(input_channels, 64, 3, padding=1)
        self.conv2 = Conv2d(64, 64, 3, padding=1)
        self.pool = MaxPool2d(2, 2)
        self.conv3 = Conv2d(64, 128, 3, padding=1)
        self.conv4 = Conv2d(128, 128, 3, padding=1)
        self.conv5 = Conv2d(128, 256, 3, padding=1)
        self.conv6 = Conv2d(256, 256, 3, padding=1)

        # Expanding path
        self.up7 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = Conv2d(256, 128, 3, padding=1)
        self.conv8 = Conv2d(128, 128, 3, padding=1)
        self.up8 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv9 = Conv2d(128, 64, 3, padding=1)
        self.conv10 = Conv2d(64, 64, 3, padding=1)

        # Output layer
        self.conv11 = nn.Conv2d(64, output_channels, 1)

    def forward(self, x):
        # Contracting path
        x1 = self.conv1(x)
        x1 = nn.functional.relu(x1)
        x1 = self.conv2(x1)
        x1 = nn.functional.relu(x1)
        x1 = self.pool(x1)
        x2 = self.conv3(x1)
        x2 = nn.functional.relu(x2)
        x2 = self.conv4(x2)
        x2 = nn.functional.relu(x2)
        x2 = self.pool(x2)
        x3 = self.conv5(x2)
        x3 = nn.functional.relu(x3)
        x3 = self.conv6(x3)
        x3 = nn.functional.relu(x3)

        # Expanding path
        x4 = self.up7(x3)
        x4 = torch.cat([x4, x2], dim=1)  # Skip connection
        x4 = self.conv7(x4)
        x4 = nn.functional.relu(x4)
        x4 = self.conv8(x4)
        x4 = nn.functional.relu(x4)
        x5 = self.up8(x4)
        x5 = torch.cat([x5, x1], dim=1)  # Skip connection
        x5 = self.conv9(x5)
        x5 = nn.functional.relu(x5)
        x5 = self.conv10(x5)
        x5 = nn.functional.relu(x5)

        # Output layer
        output = self.conv11(x5)
        return output
3. Load your data

X_train = torch.from_numpy(np.load('your_training_images.npy'))
y_train = torch.from_numpy(np.load('your_training_segmentations.npy'))
X_val = torch.from_numpy(np.load('your_validation_images
4. Define HyperParameters

input_channels = X_train.shape[1]  # Adjust based on your image channels
output_channels = 1  # For binary segmentation
5. Create UNet model

model = UNet(input_channels, output_channels)
6. Initialize Optimizer and Loss Functions

optimizer = Adam(model.parameters())
criterion = nn.BCELoss()
7. Training loop with early stopping

#Initialize Variables for EarlyStopping
best_loss = float('inf')
best_model_weights = None
patience = 10

# Training Loop with Early Stopping:**
for epoch in range(100):
    # Set model to training mode
    model.train()

    # Forward pass and loss calculation
    outputs = model(X_train)
    loss = criterion(outputs, y_train.float())  # Convert y_train to float for BCELoss

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, y_val.float())

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_weights = copy.deepcopy(model.state_dict())  # Deep copy here      
        patience = 10  # Reset patience counter
    else:
        patience -= 1
        if patience == 0:
            break

# Load the best model weights
model.load_state_dict(best_model_weights)
8. Inference

# Set model to evaluation mode
model.eval()

# Perform inference on new images
with torch.no_grad():
    new_images = torch.from_numpy(np.load('your_new_images.npy'))
    predictions = model(new_images)

# Process and visualize predictions as needed```
相关推荐
小鸡吃米…39 分钟前
机器学习 - K - 中心聚类
人工智能·机器学习·聚类
好奇龙猫1 小时前
【AI学习-comfyUI学习-第三十节-第三十一节-FLUX-SD放大工作流+FLUX图生图工作流-各个部分学习】
人工智能·学习
沈浩(种子思维作者)1 小时前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
minhuan1 小时前
大模型应用:大模型越大越好?模型参数量与效果的边际效益分析.51
人工智能·大模型参数评估·边际效益分析·大模型参数选择
Cherry的跨界思维2 小时前
28、AI测试环境搭建与全栈工具实战:从本地到云平台的完整指南
java·人工智能·vue3·ai测试·ai全栈·测试全栈·ai测试全栈
MM_MS2 小时前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
ASF1231415sd2 小时前
【基于YOLOv10n-CSP-PTB的大豆花朵检测与识别系统详解】
人工智能·yolo·目标跟踪
njsgcs2 小时前
ue python二次开发启动教程+ 导入fbx到指定文件夹
开发语言·python·unreal engine·ue
io_T_T2 小时前
迭代器 iteration、iter 与 多线程 concurrent 交叉实践(详细)
python
水如烟2 小时前
孤能子视角:“意识“的阶段性回顾,“感质“假说
人工智能