early-stopping pytorch refs

1)https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

2)https://machinelearningmastery.com/managing-a-pytorch-training-process-with-checkpoints-and-early-stopping/

3)https://pytorch.org/ignite/generated/ignite.handlers.early_stopping.EarlyStopping.html

4)https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

5)https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

复制代码
Source: https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

Step-by-Step Guide in PyTorch
1. Import libraries
import torch
import numpy as np
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Upsample, Concatenate
from torch.optim import Adam
import copy
2. Define the U-Net Architecture

class UNet(nn.Module):
    """Small two-level U-Net for image-to-image prediction.

    The contracting path applies three pairs of 3x3 convolutions (64, 128 and
    256 channels) separated by 2x2 max-pooling. The expanding path resizes the
    features bilinearly, concatenates the encoder features of the same
    resolution along the channel axis (skip connection) and applies two more
    3x3 convolutions per level. A final 1x1 convolution maps to
    ``output_channels``; no activation follows it, so ``forward`` returns raw
    scores (logits), not probabilities.

    Args:
        input_channels: Number of channels of the input images.
        output_channels: Number of channels of the produced map.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # Contracting path
        self.conv1 = nn.Conv2d(input_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)

        # Expanding path. The first convolution of each level receives the
        # upsampled features concatenated with the skip features, so its input
        # width is the sum of both channel counts.
        self.conv7 = nn.Conv2d(256 + 128, 128, 3, padding=1)
        self.conv8 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv9 = nn.Conv2d(128 + 64, 64, 3, padding=1)
        self.conv10 = nn.Conv2d(64, 64, 3, padding=1)

        # Output layer
        self.conv11 = nn.Conv2d(64, output_channels, 1)

    @staticmethod
    def _upsample_like(x, reference):
        """Bilinearly resize ``x`` to the height and width of ``reference``."""
        # Resizing to the skip tensor's own size (rather than a fixed factor
        # of 2) keeps the concatenation valid when a pooled dimension was odd.
        return nn.functional.interpolate(
            x, size=reference.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the network on a batch ``x`` of shape (N, input_channels, H, W).

        Returns a tensor of shape (N, output_channels, H, W). H and W must be
        at least 4 so that two rounds of 2x2 pooling leave a non-empty map.
        """
        relu = nn.functional.relu

        # Contracting path. Skip features are kept *before* pooling so that
        # they share the resolution of the decoder level they are merged into.
        skip1 = relu(self.conv1(x))
        skip1 = relu(self.conv2(skip1))        # 64 channels, full resolution

        skip2 = relu(self.conv3(self.pool(skip1)))
        skip2 = relu(self.conv4(skip2))        # 128 channels, 1/2 resolution

        bottom = relu(self.conv5(self.pool(skip2)))
        bottom = relu(self.conv6(bottom))      # 256 channels, 1/4 resolution

        # Expanding path
        up = self._upsample_like(bottom, skip2)
        up = torch.cat([up, skip2], dim=1)     # Skip connection -> 384 channels
        up = relu(self.conv7(up))
        up = relu(self.conv8(up))

        up = self._upsample_like(up, skip1)
        up = torch.cat([up, skip1], dim=1)     # Skip connection -> 192 channels
        up = relu(self.conv9(up))
        up = relu(self.conv10(up))

        # Output layer
        return self.conv11(up)
3. Load your data

# Load the training and validation arrays and wrap them as tensors.
# Every .npy path is a placeholder; point each one at your own file.
X_train = torch.from_numpy(np.load('your_training_images.npy'))
y_train = torch.from_numpy(np.load('your_training_segmentations.npy'))
X_val = torch.from_numpy(np.load('your_validation_images.npy'))
# y_val is needed by the validation step of the training loop.
y_val = torch.from_numpy(np.load('your_validation_segmentations.npy'))
4. Define HyperParameters

# Channel count is read from axis 1 of the training tensor
# (adjust this if your images keep their channels on another axis).
input_channels = X_train.size(1)
# A single output channel, as used for binary segmentation.
output_channels = 1
5. Create UNet model

# Build the network from the channel counts chosen above.
model = UNet(input_channels=input_channels, output_channels=output_channels)
6. Initialize Optimizer and Loss Functions

# Adam with its default settings over all model parameters.
optimizer = Adam(model.parameters())
# UNet.forward returns the output of its last convolution with no sigmoid,
# i.e. raw logits. nn.BCELoss only accepts inputs already in [0, 1], so use
# the logits variant, which applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
7. Training loop with early stopping

# Initialize variables for early stopping
best_loss = float('inf')
best_model_weights = None
patience = 10  # Epochs to wait for a new best validation loss
epochs_left = patience  # Remaining budget; refilled on every improvement

# Training loop with early stopping
for epoch in range(100):
    # Set model to training mode
    model.train()

    # Forward pass and loss calculation (full batch)
    outputs = model(X_train)
    loss = criterion(outputs, y_train.float())  # The loss expects float targets

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        val_outputs = model(X_val)
        # .item() keeps a plain Python float instead of holding on to a tensor
        val_loss = criterion(val_outputs, y_val.float()).item()

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        # Deep copy: state_dict() returns references to the live parameters,
        # which later optimizer steps would otherwise overwrite.
        best_model_weights = copy.deepcopy(model.state_dict())
        epochs_left = patience  # Reset patience counter
    else:
        epochs_left -= 1
        if epochs_left <= 0:
            break

# Load the best model weights. If the validation loss never improved (for
# example it was NaN on every epoch) there is no snapshot to restore, so the
# model keeps its current weights instead of crashing on load_state_dict(None).
if best_model_weights is not None:
    model.load_state_dict(best_model_weights)
8. Inference

# Put the model in evaluation mode before predicting
model.eval()

# Load the images to run the model on (placeholder path)
new_images = torch.from_numpy(np.load('your_new_images.npy'))

# Forward pass with gradient tracking disabled
with torch.no_grad():
    predictions = model(new_images)

# Process and visualize predictions as needed
相关推荐
霍格沃兹_测试4 分钟前
Ollama + Python 极简工作流
人工智能
资源开发与学习5 分钟前
AI智时代:一节课带你玩转 Cursor,开启快速入门与实战之旅
人工智能
强盛小灵通专卖员8 分钟前
闪电科创 SCI专业辅导
python·深度强化学习·研究生·ei会议·导师·sci期刊
跟橙姐学代码19 分钟前
自动化邮件发送的终极秘籍:Python库smtplib与email的完整玩法
前端·python·ipython
西安光锐软件19 分钟前
深度学习之损失函数
人工智能·深度学习
补三补四19 分钟前
LSTM 深度解析:从门控机制到实际应用
人工智能·rnn·lstm
astragin23 分钟前
神经网络常见层速查表
人工智能·深度学习·神经网络
嘀咕博客33 分钟前
文心快码Comate - 百度推出的AI编码助手
人工智能·百度·ai工具
cyyt42 分钟前
深度学习周报(9.8~9.14)
人工智能·深度学习
扯淡的闲人1 小时前
多语言编码Agent解决方案(2)-后端服务实现
开发语言·python·深度学习