early-stopping pytorch refs

1)https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

2)https://machinelearningmastery.com/managing-a-pytorch-training-process-with-checkpoints-and-early-stopping/

3)https://pytorch.org/ignite/generated/ignite.handlers.early_stopping.EarlyStopping.html

4)https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

5)https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376
Step-by-Step Guide in PyTorch
1. Import libraries
import torch
import numpy as np
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Upsample, Concatenate
from torch.optim import Adam
import copy
2. Define the U-Net Architecture

class UNet(nn.Module):
    """Small two-level U-Net for dense (per-pixel) prediction.

    The encoder applies two conv+ReLU stages, each followed by 2x2 max
    pooling, then a 256-channel bottleneck. The decoder upsamples twice and,
    at each level, concatenates the encoder feature map of the same spatial
    resolution (taken *before* pooling) along the channel axis.

    Args:
        input_channels: Number of channels in the input images.
        output_channels: Number of channels in the output map.

    Shape:
        Input: ``(N, input_channels, H, W)``. ``H`` and ``W`` must be at
        least 4 so that both pooling steps yield non-empty feature maps.
        Output: ``(N, output_channels, H, W)`` raw scores (logits); no
        activation is applied after the final 1x1 convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # Contracting path
        self.conv1 = nn.Conv2d(input_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)

        # Expanding path. The first conv of each decoder stage consumes the
        # upsampled features concatenated with the skip features, so its
        # in_channels is the sum of both.
        self.up7 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = nn.Conv2d(256 + 128, 128, 3, padding=1)
        self.conv8 = nn.Conv2d(128, 128, 3, padding=1)
        self.up8 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv9 = nn.Conv2d(128 + 64, 64, 3, padding=1)
        self.conv10 = nn.Conv2d(64, 64, 3, padding=1)

        # Output layer
        self.conv11 = nn.Conv2d(64, output_channels, 1)

    @staticmethod
    def _resize_like(x, reference):
        """Resize ``x`` to the spatial size of ``reference`` if they differ.

        Max pooling floors odd sizes, so a 2x upsample can come out one
        pixel short of the matching encoder map; concatenation requires the
        spatial sizes to be identical.
        """
        if x.shape[-2:] != reference.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=reference.shape[-2:], mode='bilinear', align_corners=True
            )
        return x

    def forward(self, x):
        """Run the network and return per-pixel logits at input resolution."""
        relu = nn.functional.relu

        # Contracting path: keep the pre-pool activations for the skips.
        enc1 = relu(self.conv1(x))
        enc1 = relu(self.conv2(enc1))            # 64 channels, H x W
        enc2 = relu(self.conv3(self.pool(enc1)))
        enc2 = relu(self.conv4(enc2))            # 128 channels, H/2 x W/2
        bottleneck = relu(self.conv5(self.pool(enc2)))
        bottleneck = relu(self.conv6(bottleneck))  # 256 channels, H/4 x W/4

        # Expanding path
        dec2 = self._resize_like(self.up7(bottleneck), enc2)
        dec2 = torch.cat([dec2, enc2], dim=1)  # Skip connection (256 + 128)
        dec2 = relu(self.conv7(dec2))
        dec2 = relu(self.conv8(dec2))

        dec1 = self._resize_like(self.up8(dec2), enc1)
        dec1 = torch.cat([dec1, enc1], dim=1)  # Skip connection (128 + 64)
        dec1 = relu(self.conv9(dec1))
        dec1 = relu(self.conv10(dec1))

        # Output layer
        return self.conv11(dec1)
3. Load your data

# Load the training and validation arrays from .npy files (replace the
# placeholder paths with your own). np.load keeps whatever dtype the arrays
# were saved with, so the images are cast to float32 to match the default
# dtype of the model parameters.
X_train = torch.from_numpy(np.load('your_training_images.npy')).float()
y_train = torch.from_numpy(np.load('your_training_segmentations.npy'))
X_val = torch.from_numpy(np.load('your_validation_images.npy')).float()
# y_val is required by the validation step of the training loop.
y_val = torch.from_numpy(np.load('your_validation_segmentations.npy'))
4. Define HyperParameters

# The channel count is read from axis 1 of the training images
# (assumes a channels-first layout — confirm against your data).
input_channels = X_train.size(1)
# A single output map, as used for binary segmentation.
output_channels = 1
5. Create UNet model

# Build the network from the channel counts chosen above.
model = UNet(
    input_channels=input_channels,
    output_channels=output_channels,
)
6. Initialize Optimizer and Loss Functions

# Adam with its default hyperparameters over all model parameters.
optimizer = Adam(model.parameters())
# UNet.forward returns raw, unbounded scores (there is no sigmoid after the
# final convolution). nn.BCELoss requires inputs in [0, 1], so use
# BCEWithLogitsLoss, which applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
7. Training loop with early stopping

# Early-stopping state
best_loss = float('inf')         # lowest validation loss seen so far
best_model_weights = None        # snapshot of the weights at best_loss
patience = 10                    # epochs to wait for an improvement
epochs_without_improvement = 0   # counter compared against patience

# Training loop with early stopping
for epoch in range(100):
    # Set model to training mode
    model.train()

    # Forward pass and loss calculation
    outputs = model(X_train)
    loss = criterion(outputs, y_train.float())  # loss expects float targets

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, y_val.float())

    # Early stopping: compare plain Python floats so best_loss does not keep
    # a tensor alive between epochs.
    current_val_loss = val_loss.item()
    if current_val_loss < best_loss:
        best_loss = current_val_loss
        # state_dict() returns references to the live parameters, so a deep
        # copy is needed to freeze the snapshot.
        best_model_weights = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

# Load the best model weights. The snapshot is still None if the validation
# loss never dropped below infinity (e.g. it was NaN every epoch); in that
# case keep the current weights instead of crashing.
if best_model_weights is not None:
    model.load_state_dict(best_model_weights)
8. Inference

# Set model to evaluation mode
model.eval()

# Load the new images; cast to float32 so the input dtype matches the
# default dtype of the model parameters.
new_images = torch.from_numpy(np.load('your_new_images.npy')).float()

# Perform inference on new images
with torch.no_grad():
    predictions = model(new_images)  # raw logits from the final convolution
    probabilities = torch.sigmoid(predictions)  # per-pixel values in [0, 1]

# Process and visualize predictions as needed
相关推荐
小嗷犬6 分钟前
【论文笔记】VCoder: Versatile Vision Encoders for Multimodal Large Language Models
论文阅读·人工智能·语言模型·大模型·多模态
Struart_R11 分钟前
LVSM: A LARGE VIEW SYNTHESIS MODEL WITH MINIMAL 3D INDUCTIVE BIAS 论文解读
人工智能·3d·transformer·三维重建
lucy1530275107913 分钟前
【青牛科技】GC5931:工业风扇驱动芯片的卓越替代者
人工智能·科技·单片机·嵌入式硬件·算法·机器学习
哇咔咔哇咔25 分钟前
【科普】conda、virtualenv, venv分别是什么?它们之间有什么区别?
python·conda·virtualenv
幻风_huanfeng39 分钟前
线性代数中的核心数学知识
人工智能·机器学习
CSXB991 小时前
三十四、Python基础语法(文件操作-上)
开发语言·python·功能测试·测试工具
volcanical1 小时前
LangGPT结构化提示词编写实践
人工智能
weyson1 小时前
CSharp OpenAI
人工智能·语言模型·chatgpt·openai
RestCloud1 小时前
ETLCloud异常问题分析ai功能
人工智能·ai·数据分析·etl·数据集成工具·数据异常
亚图跨际2 小时前
MATLAB和Python及R潜变量模型和降维
python·matlab·r语言·生物学·潜变量模型