Early-stopping references for PyTorch

1) https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

2) https://machinelearningmastery.com/managing-a-pytorch-training-process-with-checkpoints-and-early-stopping/

3) https://pytorch.org/ignite/generated/ignite.handlers.early_stopping.EarlyStopping.html

4) https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

5) https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

复制代码
Source: https://medium.com/@vrunda.bhattbhatt/a-step-by-step-guide-to-early-stopping-in-tensorflow-and-pytorch-59c1e3d0e376

Step-by-Step Guide in PyTorch
1. Import libraries
import torch
import numpy as np
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Upsample, Concatenate
from torch.optim import Adam
import copy
2. Define the U-Net Architecture

class UNet(nn.Module):
    """Compact U-Net with two down-sampling stages for per-pixel prediction.

    Encoder: at each level two 3x3 convolutions (padding=1, so spatial size is
    preserved) with ReLU, followed by 2x2 max-pooling that halves H and W.
    The features produced *before* each pooling step are kept as skip
    connections. Bottleneck: two 3x3 conv + ReLU at 256 channels.
    Decoder: upsample back to the matching skip resolution, concatenate the
    skip features along the channel axis, then two 3x3 conv + ReLU.

    Args:
        input_channels: Number of channels of the input tensor.
        output_channels: Number of channels of the output map.

    ``forward`` takes a batched ``(N, input_channels, H, W)`` tensor and
    returns ``(N, output_channels, H, W)``. The output comes straight from a
    1x1 convolution with no sigmoid/softmax, i.e. raw scores (logits), so it
    should be paired with a logits-based loss such as ``BCEWithLogitsLoss``.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # Contracting path
        self.conv1 = Conv2d(input_channels, 64, 3, padding=1)
        self.conv2 = Conv2d(64, 64, 3, padding=1)
        self.pool = MaxPool2d(2, 2)
        self.conv3 = Conv2d(64, 128, 3, padding=1)
        self.conv4 = Conv2d(128, 128, 3, padding=1)
        self.conv5 = Conv2d(128, 256, 3, padding=1)
        self.conv6 = Conv2d(256, 256, 3, padding=1)

        # Expanding path.
        # up7/up8 hold no parameters and are kept only so existing attribute
        # access keeps working; forward() resizes with F.interpolate to the
        # exact skip size instead. That is identical to these x2 upsamplers
        # when H and W are divisible by 4, and also handles other sizes.
        self.up7 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Input = 256 upsampled bottleneck channels + 128 skip channels.
        self.conv7 = Conv2d(256 + 128, 128, 3, padding=1)
        self.conv8 = Conv2d(128, 128, 3, padding=1)
        self.up8 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Input = 128 upsampled channels + 64 skip channels.
        self.conv9 = Conv2d(128 + 64, 64, 3, padding=1)
        self.conv10 = Conv2d(64, 64, 3, padding=1)

        # Output layer: 1x1 conv mapping 64 features to output_channels scores.
        self.conv11 = nn.Conv2d(64, output_channels, 1)

    def forward(self, x):
        relu = nn.functional.relu

        # Contracting path: keep the pre-pool activations for the skips,
        # otherwise the decoder would concatenate tensors of different H/W.
        skip1 = relu(self.conv2(relu(self.conv1(x))))        # (N, 64,  H,    W)
        down1 = self.pool(skip1)                             # (N, 64,  H/2,  W/2)
        skip2 = relu(self.conv4(relu(self.conv3(down1))))    # (N, 128, H/2,  W/2)
        down2 = self.pool(skip2)                             # (N, 128, H/4,  W/4)
        bottleneck = relu(self.conv6(relu(self.conv5(down2))))  # (N, 256, H/4, W/4)

        # Expanding path. Resize to the skip's exact spatial size: max-pooling
        # floors odd sizes, so a fixed x2 upsample would not always line up.
        up = nn.functional.interpolate(
            bottleneck, size=skip2.shape[-2:], mode='bilinear', align_corners=True
        )
        up = torch.cat([up, skip2], dim=1)  # Skip connection -> 384 channels
        up = relu(self.conv8(relu(self.conv7(up))))          # (N, 128, H/2, W/2)

        up = nn.functional.interpolate(
            up, size=skip1.shape[-2:], mode='bilinear', align_corners=True
        )
        up = torch.cat([up, skip1], dim=1)  # Skip connection -> 192 channels
        up = relu(self.conv10(relu(self.conv9(up))))         # (N, 64, H, W)

        # Output layer: raw logits, no activation.
        return self.conv11(up)
3. Load your data

# Load the pre-split training and validation sets (placeholder file names).
# Images are cast to float32 so their dtype matches the default float32
# weights of the Conv2d layers; np.load keeps whatever dtype was saved.
# NOTE(review): assumes images are stored channels-first, (N, C, H, W), and
# that masks have the same shape as the model output, presumably
# (N, 1, H, W) for binary segmentation — confirm against your data.
X_train = torch.from_numpy(np.load('your_training_images.npy')).float()
y_train = torch.from_numpy(np.load('your_training_segmentations.npy'))
X_val = torch.from_numpy(np.load('your_validation_images.npy')).float()
y_val = torch.from_numpy(np.load('your_validation_segmentations.npy'))
4. Define hyperparameters

# Channel counts for the network. Dimension 1 is read as the channel axis
# (assumes channels-first data — adjust to your images).
input_channels = X_train.size(1)
output_channels = 1  # single output channel: binary segmentation
5. Create UNet model

# Build the U-Net from the configured input/output channel counts.
model = UNet(input_channels=input_channels, output_channels=output_channels)
6. Initialize the optimizer and loss function

# Adam over all model parameters, with its default hyperparameters.
optimizer = Adam(model.parameters())
# UNet.forward ends in a plain 1x1 convolution (no sigmoid), so it emits raw,
# unbounded logits. nn.BCELoss requires inputs already in [0, 1] and errors on
# logits; BCEWithLogitsLoss applies the sigmoid internally in a numerically
# stable way (log-sum-exp trick).
criterion = nn.BCEWithLogitsLoss()
7. Training loop with early stopping

# Early-stopping configuration and state.
max_epochs = 100
patience = 10                   # consecutive non-improving epochs tolerated
best_loss = float('inf')        # best validation loss seen so far (Python float)
best_model_weights = None       # deep copy of the state_dict at best_loss
epochs_without_improvement = 0

# Training loop with early stopping. Each epoch is a single full-batch step:
# the whole training tensor goes through the model at once.
for epoch in range(max_epochs):
    model.train()

    # Forward pass and loss. .float() is a no-op for float32 tensors and
    # otherwise casts to match the float32 model weights / loss target dtype.
    outputs = model(X_train.float())
    loss = criterion(outputs, y_train.float())

    # Backward pass and optimization.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation without building an autograd graph. .item() turns the 0-d
    # loss tensor into a Python float for comparison and bookkeeping.
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val.float())
        val_loss = criterion(val_outputs, y_val.float()).item()

    if val_loss < best_loss:
        best_loss = val_loss
        # Deep copy: state_dict() returns references to the live parameter
        # tensors, which the next optimizer.step() would overwrite in place.
        best_model_weights = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

# Restore the weights from the best validation epoch. If the validation loss
# never improved (e.g. it was NaN from the first epoch, since NaN < inf is
# False), there is nothing to restore; fail loudly instead of calling
# load_state_dict(None).
if best_model_weights is None:
    raise RuntimeError(
        "Validation loss never improved during training (is it NaN?); "
        "no best weights to restore."
    )
model.load_state_dict(best_model_weights)
8. Inference

# Set model to evaluation mode.
model.eval()

# Load the images to segment. .float() casts to float32 to match the Conv2d
# weights, since np.load keeps the saved dtype (e.g. float64 or uint8).
# NOTE(review): assumes the same (N, C, H, W) layout used for training.
new_images = torch.from_numpy(np.load('your_new_images.npy')).float()

# Perform inference without building an autograd graph.
with torch.no_grad():
    predictions = model(new_images)  # raw per-pixel logits (no final activation)

# The model's last layer is a plain 1x1 conv, so convert logits to
# foreground probabilities and threshold them into a binary mask.
threshold = 0.5
probabilities = torch.sigmoid(predictions)
masks = probabilities > threshold

# Process and visualize predictions as needed
相关推荐
果粒橙_LGC2 分钟前
论文阅读系列(一)Qwen-Image Technical Report
论文阅读·人工智能·学习
WSSWWWSSW4 分钟前
Matplotlib数据可视化实战:Matplotlib子图布局与管理入门
python·信息可视化·matplotlib
WSSWWWSSW5 分钟前
Matplotlib数据可视化实战:Matplotlib图表美化与进阶教程
python·信息可视化·matplotlib
mftang9 分钟前
Python可视化工具-Bokeh:动态显示数据
开发语言·python
雷达学弱狗17 分钟前
backward怎么计算的是torch.tensor(2.0, requires_grad=True)变量的梯度
人工智能·pytorch·深度学习
Seeklike19 分钟前
diffuxers学习--AutoPipeline
人工智能·python·stable diffusion·diffusers
前端小趴菜0526 分钟前
python - 数据类型
python
杨过过儿35 分钟前
【Task01】:简介与环境配置(第一章1、2节)
人工智能·自然语言处理
小妖同学学AI36 分钟前
deepseek一键生成word和excel并一键下载
人工智能·word·excel·deepseek
黎燃44 分钟前
AI助力垃圾分类与回收的可行性研究:从算法到落地的深度解析
人工智能