使用 TensorFlow 2.x 复现 3D UX-Net（uxnet3D）

models.py

Python 代码：
from nets.model_layers import UXNETBlock, DownSampleBlock, ResBlock3D
from tensorflow import keras

def uxnet3D(input_shape, num_classes, embed_dim=48, output_activation="sigmoid"):
    """Build a 3D UX-Net style encoder/decoder segmentation model.

    Encoder: a stride-2 ``Conv3D`` stem followed by four stages of
    ``UXNETBlock`` + ``DownSampleBlock`` (each stage halves the spatial size
    and doubles the channel count). Decoder: from the deepest level upwards,
    each encoder level is passed through ``ResBlock3D``, concatenated with
    the upsampled deeper features, passed through another ``ResBlock3D`` and
    upsampled by 2. A final ``ResBlock3D`` on the raw input is fused at full
    resolution before a 1x1x1 output convolution.

    Args:
        input_shape: ``(D1, D2, D3, channels)`` without the batch axis.
            Spatial sizes may be ``None``. Known spatial sizes must be
            divisible by 32: the stem and the four downsampling stages each
            halve them, and the decoder upsamples by exactly 2, so any other
            size yields mismatched skip connections.
        num_classes: Number of channels of the output convolution.
        embed_dim: Number of filters of the stem convolution (default 48).
        output_activation: Activation of the output convolution (default
            ``"sigmoid"``); e.g. ``"softmax"`` for mutually exclusive classes.

    Returns:
        A ``keras.Model`` named ``"uxnet3D"``.

    Raises:
        ValueError: if ``input_shape`` is not rank 4, or a known spatial size
            is not divisible by 32.
    """
    # Five stride-2 reductions (stem + 4 DownSampleBlocks) -> factor 2**5.
    downsample_factor = 2 ** 5
    if len(input_shape) != 4:
        raise ValueError(
            f"input_shape must be (D1, D2, D3, channels), got {tuple(input_shape)}"
        )
    for size in input_shape[:-1]:
        if size is not None and size % downsample_factor:
            raise ValueError(
                f"spatial dims of input_shape must be divisible by "
                f"{downsample_factor}, got {tuple(input_shape)}"
            )

    inputs = keras.Input(shape=input_shape)

    # Stem (level 1): half resolution, embed_dim channels.
    embeddings = keras.layers.Conv3D(
        embed_dim, 3, 2, padding="same", name="embedding"
    )(inputs)

    # Encoder. skips[0] is level 1 (the stem); skips[k] is the output of
    # down_k, i.e. level k + 1. Layer names match the original layout.
    skips = [embeddings]
    x = embeddings
    for stage in range(1, 5):
        x = UXNETBlock(f"uxblock_{stage}")(x)
        x = DownSampleBlock(f"down_{stage}")(x)
        skips.append(x)

    # Bottleneck (level 5).
    deepest = skips[4]
    x = ResBlock3D(deepest.shape[-1], "res_block_5")(deepest)
    x = keras.layers.UpSampling3D(name="out5_up")(x)

    # Decoder: levels 4 -> 1, each fused with the upsampled deeper level.
    for level in range(4, 0, -1):
        skip = skips[level - 1]
        dim = skip.shape[-1]
        y = ResBlock3D(dim, f"res_block_{level}_1")(skip)
        y = keras.layers.Concatenate(
            name=f"cat_out{level}_out{level + 1}Up"
        )([y, x])
        y = ResBlock3D(dim, f"res_block_{level}_2")(y)
        x = keras.layers.UpSampling3D(name=f"out{level}_up")(y)

    # Full-resolution fusion with the raw input.
    out0 = ResBlock3D(inputs.shape[-1], "res_block_0")(inputs)
    out0 = keras.layers.Concatenate(name="cat_out0_out1Up")([out0, x])

    outputs = keras.layers.Conv3D(
        num_classes, 1, activation=output_activation, name="outputs"
    )(out0)

    return keras.Model(inputs, outputs, name="uxnet3D")



# Demo: build the default configuration, print the layer summary and render
# the graph to "<model name>.png".
# NOTE(review): this executes on every import of this module; consider moving
# it under `if __name__ == "__main__":`. keras.utils.plot_model also requires
# pydot and graphviz to be installed.
model = uxnet3D(num_classes=2, input_shape=(128, 128, 32, 1))
model.summary()
keras.utils.plot_model(model, show_shapes=True, to_file=model.name + ".png")

layers.py（models.py 通过 `from nets.model_layers import ...` 导入，因此该文件实际应保存为 nets/model_layers.py）

Python 代码：
import tensorflow as tf
from tensorflow import keras


class UXNETHalfBlock(keras.layers.Layer):
    """One residual unit of a UX-Net block; output shape equals input shape.

    Computes::

        x1  = inputs + Conv7(LayerNorm(inputs))
        out = x1 + Conv1_back(gelu(Conv1_x4(LayerNorm(x1))))

    where ``Conv7`` is a 7x7x7 conv keeping the channel count and
    ``Conv1_x4`` / ``Conv1_back`` are 1x1x1 convs expanding the channels
    4x and projecting them back.

    Args:
        block_name: Layer name; also stored for ``get_config``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``trainable`` or
            ``dtype`` restored by ``from_config``).
    """

    def __init__(self, block_name, **kwargs):
        # from_config() calls cls(**get_config()); the base config already
        # holds 'name' (always equal to block_name here), so drop it instead
        # of failing with "got multiple values for keyword argument 'name'".
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)
        self.block_name = block_name

        # LayerNormalization defaults to axis=-1 (the channel axis).
        self.layer_norm_1 = keras.layers.LayerNormalization()
        self.layer_norm_2 = keras.layers.LayerNormalization()

        self.add1 = keras.layers.Add()
        self.add2 = keras.layers.Add()

    def get_config(self):
        config = super().get_config()
        config.update({"block_name": self.block_name})
        return config

    def build(self, input_shape):
        num_filters = input_shape[-1]
        # NOTE(review): despite the attribute names, these are ordinary
        # (non-grouped) convolutions; a depthwise variant would need
        # groups=num_filters. Attribute names are kept unchanged so existing
        # checkpoints still map onto these sublayers.
        self.depth_wise_conv_1 = keras.layers.Conv3D(num_filters, 7, padding='same')
        self.depth_conv_scale_1 = keras.layers.Conv3D(int(num_filters * 4), 1, padding='same')
        self.depth_conv_scaleBack_1 = keras.layers.Conv3D(num_filters, 1, padding='same')

    def call(self, inputs):
        # Large-kernel spatial branch with residual connection.
        x = self.layer_norm_1(inputs)
        x = self.depth_wise_conv_1(x)
        x1 = self.add1([inputs, x])

        # 4x channel expansion / projection branch with residual connection.
        x = self.layer_norm_2(x1)
        x = self.depth_conv_scale_1(x)
        x = tf.nn.gelu(x)
        x = self.depth_conv_scaleBack_1(x)
        return self.add2([x, x1])


class UXNETBlock(keras.layers.Layer):
    """Two stacked ``UXNETHalfBlock`` units; output shape equals input shape.

    Args:
        block_name: Layer name; the sub-blocks are named
            ``f"{block_name}_part1"`` and ``f"{block_name}_part2"``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``trainable`` or
            ``dtype`` restored by ``from_config``).
    """

    def __init__(self, block_name, **kwargs):
        # The base config passed back by from_config() contains 'name',
        # which always equals block_name; drop it to avoid a TypeError.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)

        self.block_name = block_name

        self.block1 = UXNETHalfBlock(block_name=f"{block_name}_part1")
        self.block2 = UXNETHalfBlock(block_name=f"{block_name}_part2")

    def get_config(self):
        config = super().get_config()
        config.update({"block_name": self.block_name})
        return config

    def call(self, inputs):
        return self.block2(self.block1(inputs))

class DownSampleBlock(keras.layers.Layer):
    """Stride-2 downsampling stage that doubles the channel count.

    Applies a 3x3x3 ``Conv3D`` with stride 2 and ``'same'`` padding (each
    spatial size becomes ``ceil(size / 2)``) producing ``2 * in_channels``
    filters, followed by BatchNormalization and GELU.

    Args:
        block_name: Layer name; also stored for ``get_config``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``trainable`` or
            ``dtype`` restored by ``from_config``).
    """

    def __init__(self, block_name, **kwargs):
        # from_config() passes the base config's 'name' back in; it always
        # equals block_name, so drop it rather than raise a TypeError.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)

        self.block_name = block_name

    def build(self, input_shape):
        num_filters = int(input_shape[-1] * 2)
        self.conv = keras.layers.Conv3D(num_filters, 3, 2, padding='same')
        self.norm = keras.layers.BatchNormalization()

    def get_config(self):
        config = super().get_config()
        config.update({"block_name": self.block_name})
        return config

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.norm(x)
        return tf.nn.gelu(x)


class ResBlock3D(keras.layers.Layer):
    """Residual block mapping its input to ``out_channels`` channels.

    Main path: 1x1x1 conv -> 3x3x3 conv (both keeping the input channel
    count) -> 1x1x1 conv to ``out_channels``; every conv is followed by
    BatchNormalization and GELU. Shortcut: 1x1x1 conv to ``out_channels``
    followed by BatchNormalization and GELU. The two paths are summed. All
    convs use stride 1 and ``'same'`` padding, so spatial size is unchanged.

    Args:
        out_channels: Number of output channels.
        block_name: Layer name; also stored for ``get_config``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``trainable`` or
            ``dtype`` restored by ``from_config``).
    """

    def __init__(self, out_channels, block_name, **kwargs):
        # from_config() calls cls(**get_config()), which includes the base
        # 'name' key (always equal to block_name); drop it to avoid
        # "got multiple values for keyword argument 'name'".
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)

        self.add = keras.layers.Add()
        self.out_channels = out_channels
        self.block_name = block_name

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "block_name": self.block_name,
                "out_channels": self.out_channels,
            }
        )
        return config

    def build(self, input_shape):
        num_filters = input_shape[-1]
        # Main path.
        self.conv1 = keras.layers.Conv3D(num_filters, 1, 1, padding='same')
        self.norm1 = keras.layers.BatchNormalization()

        self.conv2 = keras.layers.Conv3D(num_filters, 3, 1, padding='same')
        self.norm2 = keras.layers.BatchNormalization()

        self.conv3 = keras.layers.Conv3D(self.out_channels, 1, 1, padding='same')
        self.norm3 = keras.layers.BatchNormalization()

        # Shortcut projection so both summands have out_channels channels.
        self.conv4 = keras.layers.Conv3D(self.out_channels, 1, 1, padding='same')
        self.norm4 = keras.layers.BatchNormalization()

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.norm1(x)
        x = tf.nn.gelu(x)

        x = self.conv2(x)
        x = self.norm2(x)
        x = tf.nn.gelu(x)

        x = self.conv3(x)
        x = self.norm3(x)
        x1 = tf.nn.gelu(x)

        x2 = self.conv4(inputs)
        x2 = self.norm4(x2)
        x2 = tf.nn.gelu(x2)

        return self.add([x1, x2])
    
相关推荐
charley.layabox3 小时前
8月1日ChinaJoy酒会 | 游戏出海高端私享局 | 平台 × 发行 × 投资 × 研发精英畅饮畅聊
人工智能·游戏
DFRobot智位机器人3 小时前
AIOT开发选型:行空板 K10 与 M10 适用场景与选型深度解析
人工智能
想成为风筝5 小时前
从零开始学习深度学习—水果分类之PyQt5App
人工智能·深度学习·计算机视觉·pyqt
F_D_Z5 小时前
MMaDA:多模态大型扩散语言模型
人工智能·语言模型·自然语言处理
江沉晚呤时5 小时前
在 C# 中调用 Python 脚本:实现跨语言功能集成
python·microsoft·c#·.net·.netcore·.net core
大知闲闲哟6 小时前
深度学习G2周:人脸图像生成(DCGAN)
人工智能·深度学习
飞哥数智坊6 小时前
Coze实战第15讲:钱都去哪儿了?Coze+飞书搭建自动记账系统
人工智能·coze
wenzhangli76 小时前
低代码引擎核心技术:OneCode常用动作事件速查手册及注解驱动开发详解
人工智能·低代码·云原生
电脑能手6 小时前
如何远程访问在WSL运行的Jupyter Notebook
ide·python·jupyter
Edward-tan7 小时前
CCPD 车牌数据集提取标注,并转为标准 YOLO 格式
python