unet改进笔记

论文:U-Net: Convolutional Networks for Biomedical Image Segmentation

github:https://github.com/milesial/Pytorch-UNet

改进1:数据增强,新增augmentations.py

复制代码
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Image augmentation functions
"""

import math
import random

import cv2
import numpy as np
import imgaug.augmenters as iaa



class Albumentations:
    """Photometric augmentation pipeline for a single image.

    Three stages are applied in order:
      1. an albumentations ``Compose`` (only built if the ``albumentations``
         package imports successfully),
      2. HSV jitter and CLAHE histogram equalization (both expect BGR input),
      3. one or two rounds of a randomly selected extra augmentation from the
         helper functions in this module.
    None of the transforms move pixels geometrically, so a segmentation mask
    paired with the image does not need to be transformed.
    """

    def __init__(self):
        self.transform = None
        try:
            # written against albumentations 1.0.3
            import albumentations as A

            # additive Gaussian noise (IAAAdditiveGaussianNoise / GaussNoise) was disabled
            blur = A.OneOf(
                [
                    A.MotionBlur(p=0.2),  # motion blur with a random kernel size
                    A.MedianBlur(blur_limit=3, p=0.01),  # median filter
                    A.Blur(blur_limit=3, p=0.01),  # box blur with a random kernel size
                ],
                p=0.2,  # probability of applying the chosen blur
            )
            pipeline = [
                blur,
                A.RandomBrightnessContrast(p=0.2),
                A.CLAHE(p=0.01),
                # the two transforms below have p=0.0: listed but never applied
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0),
            ]
            self.transform = A.Compose(pipeline)

            active = (f'{t}' for t in self.transform.transforms if t.p)
            print('albumentations: ' + ', '.join(active))
        except ImportError:  # albumentations not installed: this stage is skipped
            pass
        except Exception as e:
            print('albumentations: ' + f'{e}')

    def __call__(self, im, p=0.8):
        """Return an augmented version of ``im``.

        Args:
            im: image array (the HSV / CLAHE helpers treat it as BGR).
            p: probability of running the albumentations pipeline. The HSV
               jitter and the CLAHE step each fire with probability ``1 - p``.

        Returns:
            The augmented image.
        """
        if self.transform and random.random() < p:
            im = self.transform(image=im)['image']

        # reversed comparison on purpose: these fire with probability 1 - p
        if random.random() > p:
            im = augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5)
        if random.random() > p:
            im = hist_equalize(im, clahe=True, bgr=True)

        # Extra augmentations: 1 or 2 rounds, each drawing an index in 0..20.
        # Indices 0 and 2..12 pick an augmentation; 1 (random_noise, disabled)
        # and 13..20 leave the image as is.
        extra = {
            0: channel_shuffle,
            2: lambda img: random_brightness(img, brightness=0.3),
            3: lambda img: random_contrast(img, contrast=0.3),
            4: lambda img: random_saturation(img, saturation=0.5),
            5: EqExtension,
            6: HueExtension,
            7: zaodian,
            8: superpixelsaug,
            9: fogaug,
            10: cloudsaug,
            11: Coarseaug,  # fnaug (FrequencyNoiseAlpha) was disabled
            12: random_hue,
        }
        for _ in range(random.randint(1, 2)):
            aug = extra.get(random.randint(0, 20))
            if aug is not None:
                im = aug(im)

        return im


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    """Randomly jitter hue, saturation and value of a BGR image.

    Each of H, S, V is scaled by an independent random gain drawn from
    ``[1 - gain, 1 + gain]``. Hue wraps modulo 180 (OpenCV's 8-bit hue range);
    saturation and value are clipped to [0, 255].

    Args:
        im: BGR image; the 256-entry lookup tables assume 8-bit values.
        hgain, sgain, vgain: maximum relative change of H, S and V.

    Returns:
        The augmented BGR image, or ``im`` itself when all gains are 0.
    """
    if not (hgain or sgain or vgain):
        # Nothing to jitter. This path used to fall through and return None,
        # which broke callers that write ``im = augment_hsv(im, ...)``.
        return im
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
    dtype = im.dtype  # uint8

    x = np.arange(0, 256, dtype=r.dtype)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
    return cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR)


def hist_equalize(im, clahe=True, bgr=True):
    """Equalize the luma (Y) channel of a 3-channel 0-255 image.

    Args:
        im: image of shape (H, W, 3), BGR when ``bgr`` is true, else RGB.
        clahe: use CLAHE (clip limit 2.0, 8x8 tiles) instead of global
            histogram equalization.
        bgr: channel order of ``im``; the result uses the same order.

    Returns:
        The equalized image in the input's channel order.
    """
    if bgr:
        to_yuv, from_yuv = cv2.COLOR_BGR2YUV, cv2.COLOR_YUV2BGR
    else:
        to_yuv, from_yuv = cv2.COLOR_RGB2YUV, cv2.COLOR_YUV2RGB

    yuv = cv2.cvtColor(im, to_yuv)
    if clahe:
        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = equalizer.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # global equalization of Y only
    return cv2.cvtColor(yuv, from_yuv)




def channel_shuffle(img):
    """Randomly permute the channels of a 3-channel image.

    Arrays whose third axis does not hold exactly 3 channels are returned
    unchanged (the same object). ``img`` needs at least 3 dimensions.
    """
    if img.shape[2] != 3:
        return img
    order = [0, 1, 2]
    np.random.shuffle(order)  # in-place, driven by NumPy's global RNG
    return img[..., order]




def EqExtension(src):
    """Histogram-equalize each channel of a 3-channel image independently.

    ``src`` must have exactly three channels (they are unpacked as b, g, r).
    ``cv2.equalizeHist`` requires 8-bit single-channel input. The input array
    is not modified.
    """
    b, g, r = cv2.split(src.copy())
    equalized = [cv2.equalizeHist(channel) for channel in (b, g, r)]
    return cv2.merge(equalized).astype(np.uint8)


def HueExtension(src):
    """Darken a BGR image by a random factor in {0.90, 0.91, ..., 1.00}.

    Despite the name, only the V (value) channel of the HSV representation is
    scaled; hue and saturation are left untouched. Scaled values are written
    back into the integer HSV array, so fractions are truncated.
    """
    factor = random.randint(90, 100) / 100
    hsv = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)
    hsv[:, :, 2] = hsv[:, :, 2] * factor
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR).astype(np.uint8)




# Salt noise: set up to 500 random pixels to white.
def zaodian(img):
    """Return a uint8 copy of ``img`` with up to 500 random pixels set to 255.

    500 positions are drawn with replacement from NumPy's global RNG (row
    first, then column), and every channel at each position is set to 255.
    Only bright "salt" pixels are added; there are no dark "pepper" pixels.
    ``img`` must be 3-D (H, W, C). The input array is not modified.
    """
    rows, cols, _ = img.shape
    noisy = img.copy()
    for _ in range(500):
        row = np.random.randint(0, rows)
        col = np.random.randint(0, cols)
        noisy[row, col, :] = 255
    return noisy.astype(np.uint8)


#def random_noise(img, limit=[0, 0.2], p=0.5):
    #if random.random() < p:
        #H, W = img.shape[:2]
        #noise = np.random.uniform(limit[0], limit[1], size=(H,W)) * 255

        #img = img + noise[:,:,np.newaxis]*np.array([1,1,1])
        #img = np.clip(img, 0, 255).astype(np.uint8)
        
    #return img

def random_brightness(image, brightness=0.3):
    """Scale every pixel by a random factor in [1 - brightness, 1 + brightness].

    The factor comes from NumPy's global RNG. The result is clipped to
    [0, 255] and returned as uint8.
    """
    scale = np.random.uniform(-brightness, brightness) + 1
    return np.clip(image * scale, 0, 255).astype(np.uint8)

def random_contrast(img, contrast=0.3):
    """Blend ``img`` with its mean luminance to jitter contrast.

    ``alpha`` is drawn from [1 - contrast, 1 + contrast]. The output is
    ``alpha * img + (1 - alpha) * mean_luma``, clipped to [0, 255] as uint8.
    The luma weights (0.114, 0.587, 0.299) are ordered B, G, R, i.e. they
    assume a BGR image.
    """
    weights = np.array([[[0.114, 0.587, 0.299]]])
    alpha = 1.0 + np.random.uniform(-contrast, contrast)
    weighted = img * weights
    # weighted.size == 3 * H * W, so 3 * sum / size is the mean luminance
    mean_term = (3.0 * (1.0 - alpha) / weighted.size) * np.sum(weighted)
    blended = alpha * img + mean_term
    return np.clip(blended, 0, 255).astype(np.uint8)

def random_saturation(img, saturation=0.5):
    """Randomly jitter saturation by blending ``img`` with its per-pixel gray.

    ``alpha`` is drawn from [1 - saturation, 1 + saturation]. Values above 1
    boost saturation, values below 1 pull pixels toward gray, and
    ``saturation=0`` returns the image unchanged. The luma weights are in
    B, G, R order, matching ``random_contrast`` and the BGR images handled by
    the other helpers in this module. The output is clipped to [0, 255] as
    uint8.
    """
    # BT.601 luma weights ordered for BGR input (B=0.114, G=0.587, R=0.299).
    # The previous RGB-ordered weights swapped the red and blue contributions.
    coef = np.array([[[0.114, 0.587, 0.299]]])
    # Centre the blend factor on 1. Before, alpha was drawn from [-s, s], so
    # every call produced a nearly gray (or saturation-inverted) image.
    alpha = 1.0 + np.random.uniform(-saturation, saturation)
    gray = np.sum(img * coef, axis=2, keepdims=True)
    img = alpha * img + (1.0 - alpha) * gray
    return np.clip(img, 0, 255).astype(np.uint8)

def random_hue(image, hue=0.5):
    """Rotate the hue of a BGR image by a random offset.

    The offset is ``int(u * 180)`` for ``u`` uniform in [-hue, hue]. It is
    added to the H channel modulo 180 (OpenCV's 8-bit hue range).
    """
    shift = int(np.random.uniform(-hue, hue) * 180)
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    rotated = (hsv[:, :, 0].astype(int) + shift) % 180
    hsv[:, :, 0] = rotated
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)



def superpixelsaug(img):
    """Apply imgaug ``Superpixels(p_replace=0.005, max_size=50)`` to one image.

    Args:
        img: a single image (not a batch).

    Returns:
        The augmented image as uint8, with the same shape as ``img``.
    """
    aug = iaa.Superpixels(p_replace=0.005, max_size=50)
    # imgaug works on batches: wrap the image in a batch of one and take item 0.
    # np.squeeze() was used before, which also removed any other size-1 axis
    # (e.g. the channel axis of a single-channel image).
    img_aug = aug(images=np.expand_dims(img, axis=0))[0]
    return img_aug.astype(np.uint8)



def fogaug(img):
    """Apply imgaug's ``Fog`` weather augmentation to one image.

    Args:
        img: a single image (not a batch).

    Returns:
        The augmented image as uint8, with the same shape as ``img``.
    """
    aug = iaa.Fog()
    # batch of one in, item 0 out; np.squeeze() would also drop other size-1 axes
    img_aug = aug(images=np.expand_dims(img, axis=0))[0]
    return img_aug.astype(np.uint8)


def cloudsaug(img):
    """Apply imgaug's ``Clouds`` weather augmentation to one image.

    Args:
        img: a single image (not a batch).

    Returns:
        The augmented image as uint8, with the same shape as ``img``.
    """
    aug = iaa.Clouds()
    # batch of one in, item 0 out; np.squeeze() would also drop other size-1 axes
    img_aug = aug(images=np.expand_dims(img, axis=0))[0]
    return img_aug.astype(np.uint8)
 
#def fnaug(img):
    #images = np.expand_dims(img, axis=0)
    #aug = iaa.FrequencyNoiseAlpha(first=iaa.EdgeDetect(0.5))
    #images_aug = aug(images=images)
    #img_aug = np.squeeze(images_aug)
    #return img_aug.astype(np.uint8)


def Coarseaug(img):
    """Apply imgaug ``CoarseDropout(0.02, size_percent=0.5)`` to one image.

    Args:
        img: a single image (not a batch).

    Returns:
        The augmented image as uint8, with the same shape as ``img``.
    """
    aug = iaa.CoarseDropout(0.02, size_percent=0.5)
    # batch of one in, item 0 out; np.squeeze() would also drop other size-1 axes
    img_aug = aug(images=np.expand_dims(img, axis=0))[0]
    return img_aug.astype(np.uint8)

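调用方式:在 utils/data_loading.py 的 __getitem__ 函数中增加以下代码(需先在数据集的 __init__ 中定义 self.albumentations,例如 self.albumentations = Albumentations(),以及 self.mode):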

复制代码
# albumentations: photometric augmentation of the image only, in training mode.
# NOTE(review): assumes self.albumentations (e.g. an Albumentations() instance or None)
# and self.mode are set in the dataset's __init__ — confirm.
            if self.albumentations is not None and self.mode == 'train':
                image = self.albumentations(image)
            # end albumentations

改进2:基于四图拼接(类似 mosaic)的数据增强。拼接会同时改变图片和标签,因此需要用下面的实现替换 utils/data_loading.py 中的 __getitem__ 函数。仅在训练时拼接,依赖 self.train、self.img_num、self.img_list、self.mask。

复制代码
def __getitem__(self, idx):
        """Return the (image, mask) sample at ``idx``.

        Training mode (``self.train``):
            Builds a 2x2 mosaic. The image at ``idx`` goes top-left; three
            images drawn at random from the dataset fill the top-right,
            bottom-left and bottom-right quadrants. Tiles of different sizes
            are zero-padded on the bottom/right to the largest height and
            width (black image border, background mask), similar to
            Faster R-CNN batching.

        Otherwise:
            The single image/mask pair is used.

        Images are read with OpenCV (3-channel BGR). Masks are read as
        grayscale and divided by 255 (foreground == 255, background == 0).
        When ``self.transforms`` is set, the image is converted to an RGB PIL
        image and the mask to a PIL image, and both are passed to it.

        Raises:
            FileNotFoundError: an image or mask file cannot be read.
            ValueError: in training mode, an image and its mask differ in size.
        """
        def read(path, flag):
            # cv2.imread returns None (instead of raising) on a missing/corrupt file
            arr = cv2.imread(path, flag)
            if arr is None:
                raise FileNotFoundError(f"cannot read {path}")
            return arr

        def mosaic(tiles, height, width):
            # zero-pad every tile to (height, width), then stitch
            # [[up_left, up_right], [down_left, down_right]]
            padded = []
            for tile in tiles:
                canvas = np.zeros((height, width) + tile.shape[2:], np.uint8)
                canvas[:tile.shape[0], :tile.shape[1]] = tile
                padded.append(canvas)
            return np.vstack([np.hstack(padded[:2]), np.hstack(padded[2:])])

        indices = [idx]
        if self.train:
            indices += [random.randint(0, self.img_num - 1) for _ in range(3)]

        imgs = [read(self.img_list[i], 1) for i in indices]  # training data is BGR
        masks = [read(self.mask[i], 0) for i in indices]

        if self.train:
            for i, im, ma in zip(indices, imgs, masks):
                if im.shape[:2] != ma.shape[:2]:
                    raise ValueError(
                        f"image/mask size mismatch for sample {i}: {im.shape[:2]} vs {ma.shape[:2]}")
            max_height = max(im.shape[0] for im in imgs)
            max_width = max(im.shape[1] for im in imgs)
            img = mosaic(imgs, max_height, max_width)
            mask = mosaic(masks, max_height, max_width)
        else:
            img, mask = imgs[0], masks[0]

        mask = mask / 255.0  # foreground == 255, background == 0

        if self.transforms is not None:
            # OpenCV BGR ndarray -> RGB PIL image; mask -> PIL image
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            mask = Image.fromarray(mask)
            img, mask = self.transforms(img, mask)

        return img, mask

改进3:对训练图片的像素分布做随机扰动(随机乘以 0.90~1.10,再加上 -0.10~0.10),在 utils/data_loading.py 的 __getitem__ 函数中增加以下代码。注意:加性偏移只有在图片已归一化到 [0, 1] 时才有明显作用。

复制代码
if self.train:
                # Random intensity jitter: multiply by k/100 (k in 90..110), then add
                # j/100 (j in -10..10). NOTE(review): the offset is negligible for 0-255
                # pixels, so this presumably runs after scaling to [0, 1] — confirm placement.
                img = img * random.randint(90,110)/100 + random.randint(-10,10)/100
                # end random scale/shift

改进4:增加预测推理代码:

复制代码
import os
import time
import cv2

import torch
from torchvision import transforms
import numpy as np
from PIL import Image

from src import UNet




class ThyroidNoduleDetect():
    """Thyroid-nodule segmentation on ultrasound images with a trained UNet.

    The network has 2 output classes (background + nodule). ``process``
    returns the predicted foreground as a 0/255 uint8 mask.
    """

    def __init__(self, weights_path = "save_weights/N-net3_weight_best.pth"):
        """Build the UNet and load its weights.

        Args:
            weights_path: checkpoint file whose ``'model'`` entry holds the
                state dict.
        """
        # Per-channel normalisation applied to the [0, 1]-scaled RGB input.
        # mean = (0.302, 0.302, 0.302)
        # std = (0.093, 0.093, 0.093)
        self.mean = [0.326, 0.327, 0.354]
        self.std = [0.149, 0.149, 0.157]

        # Choose the device once and use it for both model and inputs. Previously
        # inputs were always sent with .cuda(), which crashed on CPU-only hosts
        # even though the model had been placed on the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(self.device))
        self.use_cuda = self.device.type == "cuda"
        self.use_float16 = False

        classes = 1  # foreground classes, excluding background
        self.model = UNet(in_channels=3, num_classes=classes + 1, base_c=32)

        # load on CPU first so GPU-saved checkpoints also load without CUDA
        self.model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
        self.model.to(self.device)
        self.model.requires_grad_(False)
        self.model.eval()  # inference mode (BatchNorm uses running stats)

    def process(self, image):
        """Segment one BGR image.

        Args:
            image: BGR image as returned by ``cv2.imread(path, 1)``.

        Returns:
            uint8 mask of shape (H, W): 255 for predicted nodule pixels, 0 elsewhere.
        """
        x = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Scale by the image's own maximum, guarding all-black frames (which
        # previously produced NaNs via 0/0), then normalise per channel.
        # NOTE(review): training pipelines often scale by 255 instead; confirm
        # this matches how the weights were trained.
        peak = float(np.max(x)) or 1.0
        x = (x / peak - self.mean) / self.std
        x = np.expand_dims(x, axis=0)

        dtype = torch.float16 if self.use_float16 else torch.float32
        if next(self.model.parameters()).dtype != dtype:
            self.model.to(dtype=dtype)  # keep weights and input in the same precision
        # (1, H, W, 3) -> (1, 3, H, W)
        x = torch.from_numpy(x).to(self.device, dtype).permute(0, 3, 1, 2)

        t1 = time.perf_counter()
        with torch.no_grad():
            output = self.model(x)
            prediction = output['out'].argmax(1).squeeze(0)
            prediction = prediction.to("cpu").numpy().astype(np.uint8)
            prediction_mask = prediction * 255
        t2 = time.perf_counter()

        tact_time = t2 - t1
        # a zero interval (coarse clocks) would otherwise raise ZeroDivisionError
        fps = 1 / tact_time if tact_time > 0 else float('inf')
        print("ThyroidNoduleDetect ", f'{tact_time} seconds, {fps} FPS, @batch_size 1')
        return prediction_mask

    def draw_circles(self, image, prediction_mask):
        """Draw predicted contours (red) and their centroids (green) onto ``image``.

        ``image`` (BGR) is modified in place and also returned.
        """
        _, binary = cv2.threshold(prediction_mask, 127, 255, cv2.THRESH_BINARY)
        # findContours returns 2 values in OpenCV 4.x and 3 in 3.x; the
        # contour list is always the second-to-last element.
        contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(image, contours, -1, (0, 0, 255), 3)

        for index, c in enumerate(contours):
            M = cv2.moments(c)
            print(index, M)
            if M["m00"] != 0.0:
                # centroid from the spatial moments
                cX = int(M["m10"] / M["m00"])
                cY = int(M["m01"] / M["m00"])
                cv2.circle(image, (cX, cY), 4, (0, 255, 0), -1)

        return image

def test_image():
    """Run the detector on one test image and save the predicted mask.

    Raises:
        FileNotFoundError: the input image cannot be read.
        OSError: the result cannot be written.
    """
    img_path = "TNUI/test/images/1079.png"
    out_path = "pre_result/N-net2_test_1079.png"
    image = cv2.imread(img_path, 1)
    if image is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"cannot read image: {img_path}")

    tnd = ThyroidNoduleDetect()
    prediction_mask = tnd.process(image)
    # cv2.imwrite neither creates directories nor raises; it only returns False
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    if not cv2.imwrite(out_path, prediction_mask):
        raise OSError(f"failed to write {out_path}")


def eval_images():
    """Save [overlay | prediction | ground truth] panels for every test image.

    Files that cv2 cannot decode are skipped. A missing ground-truth mask is
    replaced by a black panel so the prediction is still saved.
    """
    tnd = ThyroidNoduleDetect()
    image_path = "TNUI/test/images/"
    mask_path = "TNUI/test/masks/"
    out_dir = "pre_result/"
    os.makedirs(out_dir, exist_ok=True)  # cv2.imwrite does not create directories

    for name in sorted(os.listdir(image_path)):
        image = cv2.imread(os.path.join(image_path, name), 1)
        if image is None:
            print(f"skipping unreadable file: {name}")
            continue
        mask = cv2.imread(os.path.join(mask_path, name), 1)
        if mask is None:
            print(f"missing ground-truth mask for {name}")
            mask = np.zeros_like(image)

        prediction_mask = tnd.process(image)
        prediction_mask_bgr = cv2.cvtColor(prediction_mask, cv2.COLOR_GRAY2BGR)
        image = tnd.draw_circles(image, prediction_mask)  # draws in place

        cv2.imwrite(os.path.join(out_dir, name), np.hstack([image, prediction_mask_bgr, mask]))




if __name__ == "__main__":
    # Single-image smoke test; call eval_images() instead to process the whole test set.
    test_image()

改进5:MobileV3Unet,新增mobilenet_unet.py

复制代码
from collections import OrderedDict
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import mobilenet_v3_large
from .unet import Up, OutConv


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Wrap a model and return the outputs of selected top-level children.

    The children of ``model`` are re-registered in order, truncated right
    after the last child named in ``return_layers``; later children are
    dropped. This assumes ``model`` runs its children sequentially in
    registration order and never uses the same child twice in ``forward``.
    Only direct children can be selected (``model.feature1`` works,
    ``model.feature1.layer2`` does not).

    Args:
        model (nn.Module): model from which features are extracted.
        return_layers (Dict[name, new_name]): maps a child name to the key
            under which that child's output appears in the returned dict.
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        child_names = {name for name, _ in model.named_children()}
        if not set(return_layers) <= child_names:
            raise ValueError("return_layers are not present in model")
        pending = {str(k): str(v) for k, v in return_layers.items()}

        # Rebuild the backbone, dropping every module after the last requested one.
        kept = OrderedDict()
        for name, module in model.named_children():
            kept[name] = module
            pending.pop(name, None)
            if not pending:
                break

        super(IntermediateLayerGetter, self).__init__(kept)
        self.return_layers = return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the kept children in order and collect the requested outputs."""
        features = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                features[self.return_layers[name]] = x
        return features


class MobileV3Unet(nn.Module):
    """U-Net style decoder on top of a torchvision MobileNetV3-Large encoder.

    Encoder features are tapped after ``features[1]``, ``[3]``, ``[6]``,
    ``[12]`` and ``[15]`` (stage0..stage4). They are decoded with four ``Up``
    blocks, projected to ``num_classes`` by ``OutConv``, and bilinearly
    resized to the input's spatial size.
    """

    def __init__(self, num_classes, pretrain_backbone: bool = False):
        super(MobileV3Unet, self).__init__()
        # To load a local checkpoint instead of pretrained=True
        # (https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth):
        #     backbone.load_state_dict(torch.load("mobilenet_v3_large.pth", map_location='cpu'))
        features = mobilenet_v3_large(pretrained=pretrain_backbone).features

        stage_indices = [1, 3, 6, 12, 15]
        self.stage_out_channels = [features[i].out_channels for i in stage_indices]
        return_layers = {str(idx): f"stage{n}" for n, idx in enumerate(stage_indices)}
        self.backbone = IntermediateLayerGetter(features, return_layers=return_layers)

        ch = self.stage_out_channels
        # Each Up takes the sum of the deeper and the skip stage's channels
        # (presumably the two maps are concatenated inside Up).
        self.up1 = Up(ch[4] + ch[3], ch[3])
        self.up2 = Up(ch[3] + ch[2], ch[2])
        self.up3 = Up(ch[2] + ch[1], ch[1])
        self.up4 = Up(ch[1] + ch[0], ch[0])
        self.conv = OutConv(ch[0], num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        input_shape = x.shape[-2:]
        feats = self.backbone(x)
        x = self.up1(feats['stage4'], feats['stage3'])
        for up, skip in ((self.up2, 'stage2'), (self.up3, 'stage1'), (self.up4, 'stage0')):
            x = up(x, feats[skip])
        x = self.conv(x)
        # decoder output may be smaller than the input; resize logits back to it
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        return {"out": x}

改进6:VGG16UNet,新增vgg_unet.py

复制代码
from collections import OrderedDict
from typing import Dict

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models import vgg16_bn
from .unet import Up, OutConv


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Wrap a model and return the outputs of selected top-level children.

    The children of ``model`` are re-registered in order, truncated right
    after the last child named in ``return_layers``; later children are
    dropped. This assumes ``model`` runs its children sequentially in
    registration order and never uses the same child twice in ``forward``.
    Only direct children can be selected (``model.feature1`` works,
    ``model.feature1.layer2`` does not).

    Args:
        model (nn.Module): model from which features are extracted.
        return_layers (Dict[name, new_name]): maps a child name to the key
            under which that child's output appears in the returned dict.
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        child_names = {name for name, _ in model.named_children()}
        if not set(return_layers) <= child_names:
            raise ValueError("return_layers are not present in model")
        pending = {str(k): str(v) for k, v in return_layers.items()}

        # Rebuild the backbone, dropping every module after the last requested one.
        kept = OrderedDict()
        for name, module in model.named_children():
            kept[name] = module
            pending.pop(name, None)
            if not pending:
                break

        super(IntermediateLayerGetter, self).__init__(kept)
        self.return_layers = return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the kept children in order and collect the requested outputs."""
        features = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                features[self.return_layers[name]] = x
        return features


class VGG16UNet(nn.Module):
    """U-Net style decoder on top of a torchvision ``vgg16_bn`` encoder.

    Encoder features are tapped at ``features`` indices 5, 12, 22, 32 and 42
    (stage0..stage4 with 64/128/256/512/512 channels). In torchvision's
    vgg16_bn layout these are the last ReLU of each block, just before its
    max-pool. They are decoded with four ``Up`` blocks and a 1x1 ``OutConv``.
    No final resize is applied, so the output has stage0's resolution.
    """

    def __init__(self, num_classes, pretrain_backbone: bool = False):
        super(VGG16UNet, self).__init__()
        # To load a local checkpoint instead of pretrained=True
        # (https://download.pytorch.org/models/vgg16_bn-6c64b313.pth):
        #     backbone.load_state_dict(torch.load("vgg16_bn.pth", map_location='cpu'))
        features = vgg16_bn(pretrained=pretrain_backbone).features

        stage_indices = [5, 12, 22, 32, 42]
        self.stage_out_channels = [64, 128, 256, 512, 512]
        return_layers = {str(idx): f"stage{n}" for n, idx in enumerate(stage_indices)}
        self.backbone = IntermediateLayerGetter(features, return_layers=return_layers)

        ch = self.stage_out_channels
        # each Up takes the sum of the deeper and the skip stage's channels
        self.up1 = Up(ch[4] + ch[3], ch[3])
        self.up2 = Up(ch[3] + ch[2], ch[2])
        self.up3 = Up(ch[2] + ch[1], ch[1])
        self.up4 = Up(ch[1] + ch[0], ch[0])
        self.conv = OutConv(ch[0], num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        feats = self.backbone(x)
        x = self.up1(feats['stage4'], feats['stage3'])
        for up, skip in ((self.up2, 'stage2'), (self.up3, 'stage1'), (self.up4, 'stage0')):
            x = up(x, feats[skip])
        return {"out": self.conv(x)}

改进7:HNet,新增HNet.py

复制代码
from functools import partial
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

# In-place ReLU used after every branch of SDCBlock.forward.
nonlinearity = partial(F.relu, inplace=True)


class SDCBlock(nn.Module):
    """Residual block that adds four stacked/dilated 3x3 conv branches to its input.

    Branches (each passed through ``nonlinearity``, an in-place ReLU):
      1. dilate1(x)
      2. conv1x1(dilate2(x))
      3. conv1x1(dilate2(dilate1(x)))
      4. conv1x1(dilate2(dilate1(dilate3(x))))
    ``conv1x1`` is a single module shared by branches 2-4. Every conv keeps
    ``channel`` channels, and padding equals dilation for the 3x3 kernels,
    so the output has the same shape as ``x``.
    """

    def __init__(self, channel):
        super(SDCBlock, self).__init__()
        same = dict(kernel_size=3)  # 3x3 convs; padding == dilation keeps H and W
        self.dilate1 = nn.Conv2d(channel, channel, dilation=1, padding=1, **same)
        self.dilate2 = nn.Conv2d(channel, channel, dilation=3, padding=3, **same)
        self.dilate3 = nn.Conv2d(channel, channel, dilation=11, padding=11, **same)
        self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)

    def forward(self, x):
        pre_activations = (
            self.dilate1(x),
            self.conv1x1(self.dilate2(x)),
            self.conv1x1(self.dilate2(self.dilate1(x))),
            self.conv1x1(self.dilate2(self.dilate1(self.dilate3(x)))),
        )
        out = x
        for branch in pre_activations:
            # the ReLU is in place, but each branch is its own fresh tensor
            out = out + nonlinearity(branch)
        return out


class DoubleConv(nn.Sequential):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: input channels.
        out_channels: output channels of the second conv.
        mid_channels: output channels of the first conv; defaults to
            ``out_channels``.

    The convs use padding 1 and no bias (BatchNorm follows), so H and W are
    preserved.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        mid = out_channels if mid_channels is None else mid_channels
        super(DoubleConv, self).__init__(
            *self._stage(in_channels, mid),
            *self._stage(mid, out_channels),
        )

    @staticmethod
    def _stage(cin, cout):
        # one conv-bn-relu triple
        return (
            nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )


class SingleConv(nn.Sequential):
    """One 3x3 conv -> BatchNorm -> ReLU stage (padding 1, so H and W are kept).

    Args:
        in_channels: input channels.
        out_channels: output channels.
        mid_channels: unused; accepted only for signature compatibility with
            ``DoubleConv``. It used to replace ``out_channels`` as the conv
            width, so passing it silently changed the number of output channels.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(SingleConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


class Up(nn.Module):
    """Decoder step: upsample ``x1`` 2x, pad it to ``x2``'s size, concat, DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to ``DoubleConv``
            (with the transposed-conv path, also the channels of ``x1``).
        out_channels: output channels.
        bilinear: use parameter-free bilinear upsampling (with a DoubleConv
            whose middle width is ``in_channels // 2``) instead of a 2x2
            transposed conv that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # size gap between the skip connection x2 and upsampled x1, tensors are [N, C, H, W]
        dy = x2.shape[2] - x1.shape[2]
        dx = x2.shape[3] - x1.shape[3]
        # F.pad order is left, right, top, bottom; an odd gap puts the extra pixel right/bottom
        x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Sequential):
    """1x1 convolution (with bias) projecting features to ``num_classes`` logits."""

    def __init__(self, in_channels, num_classes):
        head = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        super(OutConv, self).__init__(head)

class Down(nn.Sequential):
    """Encoder step: 2x2 max-pool (stride 2) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        pool = nn.MaxPool2d(2, stride=2)
        convs = DoubleConv(in_channels, out_channels)
        super(Down, self).__init__(pool, convs)


class DownLast(nn.Sequential):
    """Last encoder step: 2x2 max-pool (stride 2) followed by a single conv-bn-relu."""

    def __init__(self, in_channels, out_channels):
        pool = nn.MaxPool2d(2, stride=2)
        conv = SingleConv(in_channels, out_channels)
        super(DownLast, self).__init__(pool, conv)

class UNet(nn.Module):
    """U-Net variant with an SDCBlock context module at the bottleneck.

    Encoder: ``in_conv`` plus three ``Down`` stages, then ``DownLast``
    (max-pool + single conv). Its output passes through ``SDCBlock`` and a
    ``SingleConv`` before four ``Up`` stages and the 1x1 ``OutConv`` head.

    Args:
        in_channels: input image channels.
        num_classes: output logit channels (background included).
        bilinear: bilinear upsampling in the decoder (otherwise transposed convs).
        base_c: channels of the first stage, doubled at each Down.

    ``forward`` returns ``{"out": logits}`` with logits shaped
    (N, num_classes, H, W).
    """

    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        # SDCBlock keeps its channel count and is built for base_c * 8, so the
        # last down stage must emit base_c * 8. It used to emit
        # base_c * 16 // factor: identical for bilinear=True (same weights as
        # before), but a channel mismatch crash for bilinear=False.
        # (Plain U-Net used Down(base_c * 8, base_c * 16 // factor) here with no SDC.)
        self.down4 = DownLast(base_c * 8, base_c * 8)
        self.sdc = SDCBlock(base_c * 8)
        # widen to the base_c * 16 // factor channels that up1 expects
        self.singleconv = SingleConv(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # bottleneck: pool + conv, dilated SDC context block, then widen
        x5 = self.singleconv(self.sdc(self.down4(x4)))
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)

        return {"out": logits}
相关推荐
数智顾问23 分钟前
【73页PPT】美的简单高效的管理逻辑(附下载方式)
大数据·人工智能·产品运营
love530love25 分钟前
【保姆级教程】阿里 Wan2.1-T2V-14B 模型本地部署全流程:从环境配置到视频生成(附避坑指南)
人工智能·windows·python·开源·大模型·github·音视频
木头左28 分钟前
结合机器学习的Backtrader跨市场交易策略研究
人工智能·机器学习·kotlin
Coovally AI模型快速验证34 分钟前
3D目标跟踪重磅突破!TrackAny3D实现「类别无关」统一建模,多项SOTA达成!
人工智能·yolo·机器学习·3d·目标跟踪·无人机·cocos2d
研梦非凡38 分钟前
CVPR 2025|基于粗略边界框监督的3D实例分割
人工智能·计算机网络·计算机视觉·3d
MiaoChuAI44 分钟前
秒出PPT vs 豆包AI PPT:实测哪款更好用?
人工智能·powerpoint
fsnine1 小时前
深度学习——残差神经网路
人工智能·深度学习
乖女子@@@1 小时前
React笔记_组件之间进行数据传递
javascript·笔记·react.js
和鲸社区2 小时前
《斯坦福CS336》作业1开源,从0手搓大模型|代码复现+免环境配置
人工智能·python·深度学习·计算机视觉·语言模型·自然语言处理·nlp
fanstuck2 小时前
2025 年高教社杯全国大学生数学建模竞赛C 题 NIPT 的时点选择与胎儿的异常判定详解(一)
人工智能·目标检测·数学建模·数据挖掘·aigc