Training a UNet Model for Geospatial Deep Learning: A Step-by-Step Guide in Google Colab

In this post, I walk through implementing the UNet architecture for geospatial image segmentation using the Landcover.ai dataset in Google Colab.

What is UNet?

UNet is a convolutional neural network architecture originally designed for biomedical image segmentation. Its structure resembles the letter "U" (hence the name): a contracting path (the encoder) and an expanding path (the decoder) joined by skip connections. These skip connections help the network retain spatial information that would otherwise be lost during downsampling.

Although UNet was developed for biomedical applications, it has proven highly effective for a wide range of image segmentation tasks, including satellite and aerial imagery analysis.
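To make the skip-connection idea concrete before we build the full network, here is a minimal, illustrative sketch of a single encoder/decoder level. The layer names are ad hoc for illustration, not the ones used in the model we define in step 7:

import torch
import torch.nn as nn

# One encoder/decoder level: downsample, upsample, then fuse with the
# saved encoder features via concatenation (the skip connection).
enc = nn.Conv2d(3, 16, kernel_size=3, padding=1)          # encoder features
down = nn.MaxPool2d(2)                                    # downsampling
up = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)  # upsampling

x = torch.randn(1, 3, 64, 64)
skip = enc(x)                # (1, 16, 64, 64), saved for the skip connection
bottleneck = down(skip)      # (1, 16, 32, 32), spatial detail is lost here
upsampled = up(bottleneck)   # (1, 16, 64, 64), back to the original resolution
fused = torch.cat([upsampled, skip], dim=1)  # (1, 32, 64, 64), detail restored
print(fused.shape)           # torch.Size([1, 32, 64, 64])

Concatenating encoder features with the upsampled decoder features is exactly what the full model does at each of its four decoder stages.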

Project Overview

This project implements a UNet model that performs semantic segmentation on the Landcover.ai dataset, which contains high-resolution aerial imagery labeled with five classes:

  • Background
  • Building
  • Woodland
  • Water
  • Road

The goal is to train a model that can automatically identify and classify these land cover types from aerial imagery.

Implementation Steps

1. Setup

We start by installing the libraries the project needs:

  • albumentations (for data augmentation)

  • torch, torchinfo, torchmetrics (PyTorch and related tooling)

  • kornia (computer vision library)

  • opencv-python (image processing)

    # Install required packages
    !pip install albumentations
    !pip install torch torchinfo torchmetrics kornia opencv-python
    import numpy as np
    import pandas as pd
    import matplotlib
    from matplotlib import pyplot as plt
    
    import os
    import cv2
    import albumentations as A
    
    import torch
    import torch.nn as nn
    from torch.nn import functional as F
    from torch.utils.data.dataset import Dataset
    from torch.utils.data import DataLoader
    
    from torchinfo import summary
    import torchmetrics as tm
    from kornia import losses
    
    # Check if CUDA is available and set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

2. Download and Prepare the Dataset

The Landcover.ai dataset consists of high-resolution orthophotos and corresponding segmentation masks. After downloading it, we split the imagery into smaller 512×512-pixel chips to make it easier to train on:

    # Create directories
    !mkdir -p landcover_data/images
    !mkdir -p landcover_data/masks
    !mkdir -p landcover_data/chips
    
    # Download the Landcover.ai dataset
    !wget -q https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip
    !unzip -q landcover.ai.v1.zip -d landcover_data
    
    # Check if files were downloaded successfully
    print("Files in landcover_data directory:")
    !ls -la landcover_data

3. Create the Chip Generation Script

The raw orthophotos are very large, so we cut them into smaller 512×512-pixel chips. This not only makes training more efficient but also yields many more training samples:

    import glob
    import os
    import cv2
    
    # Define directories
    IMGS_DIR = "landcover_data/images/"
    MASKS_DIR = "landcover_data/masks/"
    OUTPUT_DIR = "landcover_data/chips/"
    
    TARGET_SIZE = 512
    
    # Find all image and mask files
    img_paths = glob.glob(os.path.join(IMGS_DIR, "*.tif"))
    mask_paths = glob.glob(os.path.join(MASKS_DIR, "*.tif"))
    
    img_paths.sort()
    mask_paths.sort()
    
    # Check if we found images and masks
    print(f"Found {len(img_paths)} images and {len(mask_paths)} masks")
    
    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    # Process up to 'limit' image/mask pairs; lower the limit to test on a few orthophotos first
    def process_images(img_paths, mask_paths, limit=41):
        for i, (img_path, mask_path) in enumerate(zip(img_paths[:limit], mask_paths[:limit])):
            img_filename = os.path.splitext(os.path.basename(img_path))[0]
            mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
            img = cv2.imread(img_path)
            mask = cv2.imread(mask_path)
    
            # Skip if either image or mask couldn't be read
            if img is None or mask is None:
                print(f"Warning: Could not read {img_path} or {mask_path}")
                continue
    
            assert img_filename == mask_filename and img.shape[:2] == mask.shape[:2]
    
            k = 0
            for y in range(0, img.shape[0], TARGET_SIZE):
                for x in range(0, img.shape[1], TARGET_SIZE):
                    img_tile = img[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
                    mask_tile = mask[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
    
                    if img_tile.shape[0] == TARGET_SIZE and img_tile.shape[1] == TARGET_SIZE:
                        out_img_path = os.path.join(OUTPUT_DIR, "{}_{}.jpg".format(img_filename, k))
                        cv2.imwrite(out_img_path, img_tile)
    
                        out_mask_path = os.path.join(OUTPUT_DIR, "{}_{}_m.png".format(mask_filename, k))
                        cv2.imwrite(out_mask_path, mask_tile)
    
                    k += 1
    
            print(f"Processed {img_filename} ({i + 1}/{min(limit, len(img_paths))})")
    
    # Process a limited number of images first to test
    # process_images(img_paths, mask_paths, limit=2)
    # Process all images since test worked well
    process_images(img_paths, mask_paths)
    
    # Check if chips were created
    print("Generated chips:")
    !ls -la landcover_data/chips/ | head
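As a rough sanity check on the chip count, the tiling loop keeps only full 512×512 tiles, so an orthophoto of width W and height H yields ⌊W/512⌋ × ⌊H/512⌋ chips. Assuming a roughly 9000×9500-pixel orthophoto (actual sizes vary from file to file), that works out as follows:

# Rough chip-count estimate: only full 512x512 tiles are kept.
# The 9000x9500 px size below is an illustrative assumption; real
# Landcover.ai orthophoto dimensions vary across the dataset.
def expected_chips(width, height, tile=512):
    return (width // tile) * (height // tile)

print(expected_chips(9000, 9500))  # 17 * 18 = 306 chips from one such image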

4. Create Custom Train/Val/Test Splits

We split the dataset into training (70%), validation (15%), and test (15%) sets:

    # Create our own train/val/test splits based on generated chips
    import random
    
    # Get list of all image files (not masks)
all_files_in_chips_folder = os.listdir(OUTPUT_DIR)
    print(f"Total files found in {OUTPUT_DIR} folder: {len(all_files_in_chips_folder)}")
    all_image_chips = [f for f in os.listdir(OUTPUT_DIR) if f.endswith('.jpg')]
    print(f"Total image chips found: {len(all_image_chips)}")
    all_mask_chips = [f for f in os.listdir(OUTPUT_DIR) if f.endswith('_m.png')]
    print(f"Total mask chips found: {len(all_mask_chips)}")
    
    # Shuffle the list for randomization
    random.seed(42)  # For reproducibility
    random.shuffle(all_image_chips)
    
    # Split into train/val/test (70%/15%/15%)
    train_size = int(0.7 * len(all_image_chips))
    val_size = int(0.15 * len(all_image_chips))
    
    train_files = all_image_chips[:train_size]
    val_files = all_image_chips[train_size:train_size+val_size]
    test_files = all_image_chips[train_size+val_size:]
    
    print(f"Split sizes - Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")
    
    # Create new txt files with our splits
    with open("landcover_data/train.txt", "w") as f:
        for file in train_files:
            f.write(f"{os.path.splitext(file)[0]}\n")
    
    with open("landcover_data/val.txt", "w") as f:
        for file in val_files:
            f.write(f"{os.path.splitext(file)[0]}\n")
    
    with open("landcover_data/test.txt", "w") as f:
        for file in test_files:
            f.write(f"{os.path.splitext(file)[0]}\n")
    
    print("Created new train/val/test split files")

5. Define the Class List and Process the Train/Val/Test Splits

    # Define our class names
    CLASSES = ['background', 'building', 'woodlands', 'water', 'road']
    OUTPUT_DIR = "landcover_data/chips/"
    
# Check that the train/val/test split files we just created exist
    !ls -la landcover_data/*.txt
    
    # Now read these files into DataFrames
    trainDF = pd.read_csv("landcover_data/train.txt", header=None, names=["file"])
    trainDF["img"] = OUTPUT_DIR + trainDF['file'] + ".jpg"
    trainDF["mask"] = OUTPUT_DIR + trainDF['file'] + "_m.png"
    
    valDF = pd.read_csv("landcover_data/val.txt", header=None, names=["file"])
    valDF["img"] = OUTPUT_DIR + valDF['file'] + ".jpg"
    valDF["mask"] = OUTPUT_DIR + valDF['file'] + "_m.png"
    
    testDF = pd.read_csv("landcover_data/test.txt", header=None, names=["file"])
    testDF["img"] = OUTPUT_DIR + testDF['file'] + ".jpg"
    testDF["mask"] = OUTPUT_DIR + testDF['file'] + "_m.png"
    
    # Display the first few rows of the training DataFrame
    print("Training DataFrame sample:")
    print(trainDF.head())
    
    # Check if the image and mask files exist
    print("\nChecking if files exist:")
    print(f"First training image exists: {os.path.exists(trainDF['img'].iloc[0])}")
    print(f"First training mask exists: {os.path.exists(trainDF['mask'].iloc[0])}")
    # Function to display a few samples from the training data
    def display_samples(df, num_samples=3):
        plt.figure(figsize=(15, 5*num_samples))
    
        for i in range(min(num_samples, len(df))):
            # Read image and mask
            img_path = df['img'].iloc[i]
            mask_path = df['mask'].iloc[i]
    
            try:
                # Read and convert image
                image = cv2.imread(img_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
                # Read mask
                mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
                if mask is not None:
                    mask = mask[:,:,0]  # Take first channel
    
                # Display
                plt.subplot(num_samples, 2, 2*i+1)
                plt.title(f"Image {i+1}: {os.path.basename(img_path)}")
                plt.imshow(image)
                plt.axis('off')
    
                plt.subplot(num_samples, 2, 2*i+2)
                plt.title(f"Mask {i+1}")
                if mask is not None:
                    plt.imshow(mask)
                else:
                    plt.text(0.5, 0.5, "Mask not found", horizontalalignment='center', verticalalignment='center')
                plt.axis('off')
    
            except Exception as e:
                plt.subplot(num_samples, 2, 2*i+1)
                plt.text(0.5, 0.5, f"Error loading image: {str(e)}", horizontalalignment='center', verticalalignment='center')
                plt.axis('off')
    
                plt.subplot(num_samples, 2, 2*i+2)
                plt.text(0.5, 0.5, f"Error loading mask: {str(e)}", horizontalalignment='center', verticalalignment='center')
                plt.axis('off')
    
        plt.tight_layout()
        plt.show()
    
    # Display samples from the training set
    print("Displaying samples from the training dataset:")
    display_samples(trainDF, num_samples=4)

Let's examine the different mask values in the Landcover.ai dataset and what they mean:

# Let's examine the mask values and their meanings
def explore_mask_values(df, num_samples=5):
    # Define the class names and colors for visualization
    class_names = ['background', 'building', 'woodland', 'water', 'road']
    class_colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]

    unique_values = set()

    print("Examining mask values in sample images:")

    for i in range(min(num_samples, len(df))):
        mask_path = df['mask'].iloc[i]

        try:
            # Read mask
            mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
            if mask is not None:
                # Take first channel (since masks are single-channel)
                mask_channel = mask[:,:,0]

                # Get unique values
                values = np.unique(mask_channel)
                unique_values.update(values)

                print(f"Mask {i+1}: {os.path.basename(mask_path)}")
                print(f"  Unique values: {values}")

                # Count pixels for each class
                for val in values:
                    if val < len(class_names):
                        class_name = class_names[val]
                        count = np.sum(mask_channel == val)
                        percentage = (count / mask_channel.size) * 100
                        print(f"  Class {val} ({class_name}): {count} pixels ({percentage:.2f}%)")
                print()

                # Display the mask with color-coding
                plt.figure(figsize=(10, 10))

                # Original mask
                plt.subplot(1, 2, 1)
                plt.title(f"Original Mask {i+1}")
                plt.imshow(mask_channel)
                plt.colorbar(label='Class ID')

                # Colored mask
                plt.subplot(1, 2, 2)
                plt.title(f"Color-coded Mask {i+1}")

                # Create RGB mask for visualization
                h, w = mask_channel.shape
                rgb_mask = np.zeros((h, w, 3), dtype=np.uint8)

                for cls_id, color in enumerate(class_colors):
                    if cls_id < len(class_colors):  # Ensure we don't go out of bounds
                        rgb_mask[mask_channel == cls_id] = color

                plt.imshow(rgb_mask)

                # Add legend
                legend_elements = [plt.Rectangle((0, 0), 1, 1, color=np.array(color)/255)
                                for color in class_colors[:len(class_names)]]
                plt.legend(legend_elements, class_names, loc='upper right')

                plt.tight_layout()
                plt.show()

        except Exception as e:
            print(f"Error processing mask {mask_path}: {str(e)}")

    print(f"All unique values found across {num_samples} masks: {sorted(unique_values)}")

# Explore the mask values
explore_mask_values(trainDF, num_samples=3)
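The per-mask percentages above already hint at class imbalance. Aggregating pixel counts over a larger sample of training masks gives a clearer picture; a small sketch, sampling 200 masks to keep it fast:

# Aggregate per-class pixel counts over a sample of training masks to
# quantify class imbalance (running over the full set works too, just slower).
from collections import Counter

pixel_counts = Counter()
for mask_path in trainDF['mask'].iloc[:200]:  # sample of 200 masks
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if mask is None:
        continue
    values, counts = np.unique(mask[:, :, 0], return_counts=True)
    for v, c in zip(values, counts):
        pixel_counts[int(v)] += int(c)

total = sum(pixel_counts.values())
for cls_id, count in sorted(pixel_counts.items()):
    name = CLASSES[cls_id] if cls_id < len(CLASSES) else f"unknown ({cls_id})"
    print(f"{name}: {100 * count / total:.2f}%")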

6. Create the Custom Dataset Class, DataLoaders, and Data Augmentation

Now let's implement the custom dataset class and the augmentation transforms:

# Define the custom dataset class with better error handling
class MultiClassSegDataset(Dataset):
    def __init__(self, df, transform=None):
        # Filter out rows with non-existent files
        valid_rows = []
        for i, row in df.iterrows():
            if os.path.exists(row['img']) and os.path.exists(row['mask']):
                valid_rows.append(i)

        self.df = df.iloc[valid_rows].reset_index(drop=True)
        print(f"Found {len(self.df)} valid image-mask pairs out of {len(df)} entries")

        self.transform = transform

    def __getitem__(self, idx):
        image_name = self.df.iloc[idx, 1]
        mask_name = self.df.iloc[idx, 2]

        # Read image and mask with error checking
        image = cv2.imread(image_name)
        if image is None:
            raise ValueError(f"Failed to read image: {image_name}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise ValueError(f"Failed to read mask: {mask_name}")

        # Convert to appropriate types
        image = image.astype('uint8')
        mask = mask[:,:,0]  # Take first channel of mask

        # Apply transformations if provided
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Convert to tensors
        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask)

        # Convert image to channels-first format and normalize
        image = image.permute(2, 0, 1)
        image = image.float()/255

        # Convert mask to long type
        mask = mask.long()

        return image, mask

    def __len__(self):
        return len(self.df)

# Verify that the file paths in the DataFrames are correct
print("Checking file paths in DataFrames:")
print(f"Example training image path: {trainDF['img'].iloc[0]}")
print(f"Example training mask path: {trainDF['mask'].iloc[0]}")
print(f"File exists: {os.path.exists(trainDF['img'].iloc[0])}")

# Check and perhaps modify file paths if needed
# For example, if paths have double slashes or other issues:
trainDF['img'] = trainDF['img'].apply(lambda x: os.path.normpath(x))
trainDF['mask'] = trainDF['mask'].apply(lambda x: os.path.normpath(x))
valDF['img'] = valDF['img'].apply(lambda x: os.path.normpath(x))
valDF['mask'] = valDF['mask'].apply(lambda x: os.path.normpath(x))
testDF['img'] = testDF['img'].apply(lambda x: os.path.normpath(x))
testDF['mask'] = testDF['mask'].apply(lambda x: os.path.normpath(x))

# Define transforms for validation and test sets
test_transform = A.Compose([
    A.PadIfNeeded(min_height=512, min_width=512, border_mode=4),
    A.Resize(512, 512),
])

# Define transforms for training set (with augmentations)
train_transform = A.Compose([
    A.PadIfNeeded(min_height=512, min_width=512, border_mode=4),
    A.Resize(512, 512),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.MedianBlur(blur_limit=3, p=0.1),  # Removed 'always_apply' parameter
])

# Initialize datasets with error handling
trainDS = MultiClassSegDataset(trainDF, transform=train_transform)
valDS = MultiClassSegDataset(valDF, transform=test_transform)
testDS = MultiClassSegDataset(testDF, transform=test_transform)

# Print dataset sizes
print(f"Final Training Samples: {len(trainDS)}")
print(f"Final Validation Samples: {len(valDS)}")
print(f"Final Testing Samples: {len(testDS)}")

# Initialize DataLoaders with small batch size initially
trainDL = DataLoader(trainDS, batch_size=4, shuffle=True, num_workers=0, pin_memory=False, drop_last=True)
valDL = DataLoader(valDS, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, drop_last=True)
testDL = DataLoader(testDS, batch_size=4, shuffle=False, num_workers=0, pin_memory=False, drop_last=True)

# Try to get a single batch
try:
    batch = next(iter(trainDL))
    images, labels = batch
    print(f"Successfully loaded a batch!")
    print(f"Batch shapes - Images: {images.shape}, Labels: {labels.shape}")
    print(f"Data types - Images: {images.dtype}, Labels: {labels.dtype}")

    # Display a sample image and mask
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.title("Sample Image")
    plt.imshow(images[0].permute(1, 2, 0))  # Convert back to channels-last for display

    plt.subplot(1, 2, 2)
    plt.title("Sample Mask")
    plt.imshow(labels[0])

    plt.show()
except Exception as e:
    print(f"Error loading batch: {str(e)}")

    # If we failed, try to identify the issue
    print("\nInvestigating the issue:")
    sample_idx = 0
    print(f"Trying to load image: {trainDS.df['img'].iloc[sample_idx]}")
    try:
        img = cv2.imread(trainDS.df['img'].iloc[sample_idx])
        if img is None:
            print(f"cv2.imread returned None - file might not exist or has format issues")
        else:
            print(f"Successfully loaded image with shape: {img.shape}")
    except Exception as e:
        print(f"Error: {str(e)}")

Output:

Checking file paths in DataFrames:
Example training image path: landcover_data/chips/M-34--65-Da-4--4_57.jpg
Example training mask path: landcover_data/chips/M-34--65-Da-4--4_57_m.png
File exists: True
Found 7471 valid image-mask pairs out of 7471 entries
Found 1601 valid image-mask pairs out of 1601 entries
Found 1602 valid image-mask pairs out of 1602 entries
Final Training Samples: 7471
Final Validation Samples: 1601
Final Testing Samples: 1602
Successfully loaded a batch!
Batch shapes - Images: torch.Size([4, 3, 512, 512]), Labels: torch.Size([4, 512, 512])
Data types - Images: torch.float32, Labels: torch.int64

7. Define the UNet Architecture

# Helper functions for UNet architecture
def double_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.Conv2d(inChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True),
        nn.Conv2d(outChannels, outChannels, kernel_size=(3,3), stride=1, padding=1),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )

def up_conv(inChannels, outChannels):
    return nn.Sequential(
        nn.ConvTranspose2d(inChannels, outChannels, kernel_size=(2,2), stride=2),
        nn.BatchNorm2d(outChannels),
        nn.ReLU(inplace=True)
    )

# UNet model architecture
class myUNet(nn.Module):
    def __init__(self, encoderChn, decoderChn, inChn, botChn, nCls):
        super().__init__()
        self.encoderChn = encoderChn
        self.decoderChn = decoderChn
        self.botChn = botChn
        self.nCls = nCls

        # Encoder path
        self.encoder1 = double_conv(inChn, encoderChn[0])

        self.encoder2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            double_conv(encoderChn[0], encoderChn[1])
        )

        self.encoder3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            double_conv(encoderChn[1], encoderChn[2])
        )

        self.encoder4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            double_conv(encoderChn[2], encoderChn[3])
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            double_conv(encoderChn[3], botChn)
        )

        # Decoder path with skip connections
        self.decoder1up = up_conv(botChn, botChn)
        self.decoder1 = double_conv(encoderChn[3]+botChn, decoderChn[0])

        self.decoder2up = up_conv(decoderChn[0], decoderChn[0])
        self.decoder2 = double_conv(encoderChn[2]+decoderChn[0], decoderChn[1])

        self.decoder3up = up_conv(decoderChn[1], decoderChn[1])
        self.decoder3 = double_conv(encoderChn[1]+decoderChn[1], decoderChn[2])

        self.decoder4up = up_conv(decoderChn[2], decoderChn[2])
        self.decoder4 = double_conv(encoderChn[0]+decoderChn[2], decoderChn[3])

        # Final classifier
        self.classifier = nn.Conv2d(decoderChn[3], nCls, kernel_size=(1,1))

    def forward(self, x):
        # Encoder
        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)
        encoder4 = self.encoder4(encoder3)

        # Bottleneck
        x = self.bottleneck(encoder4)

        # Decoder with skip connections
        x = self.decoder1up(x)
        x = torch.concat([x, encoder4], dim=1)
        x = self.decoder1(x)

        x = self.decoder2up(x)
        x = torch.concat([x, encoder3], dim=1)
        x = self.decoder2(x)

        x = self.decoder3up(x)
        x = torch.concat([x, encoder2], dim=1)
        x = self.decoder3(x)

        x = self.decoder4up(x)
        x = torch.concat([x, encoder1], dim=1)
        x = self.decoder4(x)

        # Classifier head
        x = self.classifier(x)

        return x

# Instantiate the model
model = myUNet(
    encoderChn=[16, 32, 64, 128],
    decoderChn=[128, 64, 32, 16],
    inChn=3,
    botChn=512,
    nCls=5
).to(device)

# Print model summary
summary(model, (4, 3, 512, 512))  # Match our batch size of 4
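Before wiring up training, a quick forward pass on a dummy batch confirms that the output has the shape the loss and metrics expect, i.e. one logit map per class:

# Sanity-check the model output shape with a dummy batch before training.
with torch.no_grad():
    dummy = torch.randn(2, 3, 512, 512).to(device)
    out = model(dummy)
print(out.shape)  # Expected: torch.Size([2, 5, 512, 512])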

8. Set Up the Loss Function, Optimizer, and Metrics

# Define loss function - Dice Loss from kornia
criterion = losses.DiceLoss(average="macro")

# Define optimizer with AdamW and initial learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

# Define learning rate scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=50,
    steps_per_epoch=len(trainDL),
    three_phase=True
)

# Define evaluation metrics
acc = tm.Accuracy(task="multiclass", num_classes=5, average="micro").to(device)
f1 = tm.F1Score(task="multiclass", num_classes=5, average="macro").to(device)
kappa = tm.CohenKappa(task="multiclass", num_classes=5).to(device)

# Define the number of epochs and save folder
epochs = 50
save_folder = "model_checkpoints/"
os.makedirs(save_folder, exist_ok=True)

# Initialize lists to store metrics
eNum = []
t_loss = []
t_acc = []
t_f1 = []
t_kappa = []
v_loss = []
v_acc = []
v_f1 = []
v_kappa = []

# Initialize best validation F1 score
f1VMax = 0.0

print("Setup complete for training!")
print(f"Training for {epochs} epochs")
print(f"Model checkpoints will be saved to: {save_folder}")
print(f"Loss function: {criterion.__class__.__name__}")
print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Learning rate scheduler: {scheduler.__class__.__name__}")

Output:

Setup complete for training!
Training for 50 epochs
Model checkpoints will be saved to: model_checkpoints/
Loss function: DiceLoss
Optimizer: AdamW
Learning rate scheduler: OneCycleLR
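It can help to see what OneCycleLR will actually do to the learning rate over the run. The sketch below builds a throwaway optimizer and scheduler (so the real ones keep their state) and plots the schedule:

# Visualize the one-cycle learning-rate schedule on throwaway objects so the
# real optimizer/scheduler state is untouched.
tmp_opt = torch.optim.AdamW(model.parameters(), lr=0.0001)
tmp_sched = torch.optim.lr_scheduler.OneCycleLR(
    tmp_opt, max_lr=0.01, epochs=epochs,
    steps_per_epoch=len(trainDL), three_phase=True)

lrs = []
for _ in range(epochs * len(trainDL)):
    tmp_opt.step()  # avoids the "scheduler before optimizer" warning
    tmp_sched.step()
    lrs.append(tmp_sched.get_last_lr()[0])

plt.plot(lrs)
plt.xlabel("Batch")
plt.ylabel("Learning rate")
plt.title("OneCycleLR schedule")
plt.show()

With three_phase=True the rate ramps up toward max_lr, ramps back down, then decays well below the initial rate in a final annihilation phase.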

9. Training Loop

# For demonstration purposes, let's set a smaller number of epochs
demo_epochs = 20  # You can increase this if you want to train longer

print(f"Starting training for {demo_epochs} epochs...")

# Loop over epochs
for epoch in range(1, demo_epochs + 1):
    # Initialize running loss for epoch
    running_loss = 0.0

    # Make sure model is in training mode
    model.train()

    # Loop over training batches
    for batch_idx, (inputs, targets) in enumerate(trainDL):
        # Get data and move to device
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Calculate loss
        loss = criterion(outputs, targets)

        # Calculate metrics
        acc_val = acc(outputs, targets)
        f1_val = f1(outputs, targets)
        kappa_val = kappa(outputs, targets)

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        # Update learning rate
        scheduler.step()

        # Update running loss with batch results
        running_loss += loss.item()

        # Print progress every 5 batches
        if (batch_idx + 1) % 5 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx + 1}/{len(trainDL)}, Loss: {loss.item():.4f}')

    # Accumulate loss and metrics at end of training epoch
    epoch_loss = running_loss / len(trainDL)
    acc_train = acc.compute()
    f1_train = f1.compute()
    kappa_train = kappa.compute()

    # Print losses and metrics at end of each training epoch
    print(f'Epoch: {epoch}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {acc_train:.4f}, Training F1: {f1_train:.4f}, Training Kappa: {kappa_train:.4f}')

    # Append results
    eNum.append(epoch)
    t_loss.append(epoch_loss)
    t_acc.append(acc_train.detach().cpu().numpy())
    t_f1.append(f1_train.detach().cpu().numpy())
    t_kappa.append(kappa_train.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Make sure model is in eval mode
    model.eval()

    # Loop over validation batches
    with torch.no_grad():
        # Initialize running validation loss
        running_loss_v = 0.0
        for batch_idx, (inputs, targets) in enumerate(valDL):
            # Get data and move to device
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)

            # Calculate validation loss
            loss_v = criterion(outputs, targets)

            # Update running loss with batch results
            running_loss_v += loss_v.item()

            # Calculate metrics
            acc_val = acc(outputs, targets)
            f1_val = f1(outputs, targets)
            kappa_val = kappa(outputs, targets)

    # Accumulate loss and metrics at end of validation epoch
    epoch_loss_v = running_loss_v / len(valDL)
    acc_val = acc.compute()
    f1_val = f1.compute()
    kappa_val = kappa.compute()

    # Print validation loss and metrics
    print(f'Validation Loss: {epoch_loss_v:.4f}, Validation Accuracy: {acc_val:.4f}, Validation F1: {f1_val:.4f}, Validation Kappa: {kappa_val:.4f}')

    # Append results
    v_loss.append(epoch_loss_v)
    v_acc.append(acc_val.detach().cpu().numpy())
    v_f1.append(f1_val.detach().cpu().numpy())
    v_kappa.append(kappa_val.detach().cpu().numpy())

    # Reset metrics
    acc.reset()
    f1.reset()
    kappa.reset()

    # Save model if validation F1-score improves
    f1_val_np = f1_val.detach().cpu().numpy()
    if f1_val_np > f1VMax:
        f1VMax = f1_val_np
        torch.save(model.state_dict(), os.path.join(save_folder, 'landcoverai_unet_model.pt'))
        print(f'Model saved for epoch {epoch}.')

# Save the training metrics to a CSV
results_df = pd.DataFrame({
    "epoch": eNum,
    "training_loss": t_loss,
    "training_accuracy": t_acc,
    "training_f1": t_f1,
    "training_kappa": t_kappa,
    "val_loss": v_loss,
    "val_accuracy": v_acc,
    "val_f1": v_f1,
    "val_kappa": v_kappa
})

results_df.to_csv(os.path.join(save_folder, "training_results.csv"), index=False)
print(f"Training completed. Results saved to {os.path.join(save_folder, 'training_results.csv')}")

10. Visualize the Training Results

# Load the training results
results_df = pd.read_csv(os.path.join(save_folder, "training_results.csv"))

# Plot training and validation loss
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(results_df['epoch'], results_df['training_loss'], label='Training Loss')
plt.plot(results_df['epoch'], results_df['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

# Plot training and validation F1 scores
plt.subplot(1, 2, 2)
plt.plot(results_df['epoch'], results_df['training_f1'], label='Training F1')
plt.plot(results_df['epoch'], results_df['val_f1'], label='Validation F1')
plt.xlabel('Epoch')
plt.ylabel('F1 Score')
plt.title('Training and Validation F1 Score')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Print the epoch with the best validation F1 score
best_epoch = results_df.loc[results_df['val_f1'].idxmax()]
print(f"Best model was saved at epoch {int(best_epoch['epoch'])} with validation F1 score of {best_epoch['val_f1']:.4f}")

Due to GPU constraints I only trained for 20 epochs; Professor Maxwell trained for 50 epochs, and you can see those results in the original article I shared in the introduction.

Challenges and Lessons Learned

I ran into several challenges over the course of the project:

  1. Data preprocessing: handling the large orthophotos required careful tiling and preprocessing.
  2. Class imbalance: some land cover types (such as water) occur far less often than others, which calls for an appropriate loss function (see the sketch after this list).
  3. Model tuning: finding good learning rates and other hyperparameters took experimentation and heavy GPU usage, which gets expensive.
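On the class-imbalance point, the Dice loss used above already helps, but another common option is a class-weighted cross-entropy with inverse-frequency weights. A hedged sketch, assuming the pixel_counts Counter from the mask-exploration step; this was not part of the training run above:

# One way to counter class imbalance: inverse-frequency class weights in a
# cross-entropy loss. Assumes the pixel_counts Counter computed earlier.
freqs = torch.tensor([pixel_counts.get(i, 1) for i in range(5)], dtype=torch.float)
weights = 1.0 / freqs
weights = weights / weights.sum() * 5  # normalize so the mean weight is 1
weighted_ce = nn.CrossEntropyLoss(weight=weights.to(device))
# weighted_ce(outputs, targets) could then replace or complement the Dice loss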
