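Getting Started with Keras and Building a Residual Network

I found this old study note still sitting in my drafts folder. Hope it helps anyone who needs it!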

Contents

1. Getting started with Keras

2. Residual networks (ResNet)

2.1. The identity block

2.2. The convolutional block

Building a 50-layer residual network

Testing on your own data


1. Getting started with Keras

This post draws on a reference source.

Outline of a Keras model:

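```python
from keras.layers import (Input, ZeroPadding2D, Conv2D, BatchNormalization,
                          Activation, MaxPooling2D, Flatten, Dense)
from keras.models import Model

def model(input_shape):
    """
    Model outline
    """
    # Define the input tensor (a placeholder) with shape input_shape
    X_input = Input(input_shape)

    # Zero padding: pad the borders of X_input with zeros (3 pixels per side)
    X = ZeroPadding2D((3, 3))(X_input)

    # Apply a CONV -> BN -> RELU block to X
    # Conv2D arguments: number of output filters, kernel_size, strides
    X = Conv2D(32, (7, 7), strides=(1, 1), name='conv0')(X)
    X = BatchNormalization(axis=3, name='bn0')(X)
    X = Activation('relu')(X)

    # Max pooling layer
    X = MaxPooling2D((2, 2), name='max_pool')(X)

    # Flatten the feature maps into a vector, then add a fully connected layer
    X = Flatten()(X)
    # Dense arguments: number of output units, activation function
    X = Dense(1, activation='sigmoid', name='fc')(X)

    # Create the model instance, which we can then train and test.
    # As in TensorFlow, you only need to tell it how the inputs relate to the outputs.
    model = Model(inputs=X_input, outputs=X, name='HappyModel')

    return model
```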

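Once the model is designed:

1. Create a model instance.

2. Compile the model. The arguments are, in order, (1) the optimizer, (2) the loss function, and (3) the metrics. Compiling tells the model the concrete details of how to run; without it, the model would not know how to compute the loss or how to optimize once samples are fed in.

3. Train the model: pass in the training samples and labels, the number of epochs, and the batch size.

4. Evaluate the model: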

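```python
# Create a model instance
# (X_train, Y_train, X_test, Y_test are assumed to be loaded already; the loading code is not shown here)
model_test = model(X_train.shape[1:])
# Compile the model: optimizer, loss, metrics
model_test.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
# Train the model
# Note: this takes roughly 6-10 minutes.
model_test.fit(X_train, Y_train, epochs=40, batch_size=50)
# Evaluate the model; evaluate() returns [loss, accuracy]
preds = model_test.evaluate(X_test, Y_test, batch_size=32, verbose=1, sample_weight=None)
print("Loss = " + str(preds[0]))
print("Accuracy = " + str(preds[1]))
```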
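For intuition about the two numbers `evaluate` returns here, below is a minimal NumPy sketch of binary cross-entropy and accuracy, the loss and metric passed to `compile`. It assumes a 0.5 decision threshold and a simple epsilon clip; Keras's own implementation differs in details. The helper name `evaluate_binary` is hypothetical.

```python
import numpy as np

def evaluate_binary(y_true, y_prob, eps=1e-7):
    """Hypothetical helper: mean binary cross-entropy and accuracy for sigmoid outputs."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    # Clip probabilities away from 0 and 1 so log() stays finite
    y_prob = np.clip(np.asarray(y_prob, dtype=float).ravel(), eps, 1 - eps)
    # Binary cross-entropy averaged over the samples
    loss = -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    # A prediction counts as class 1 when its probability exceeds 0.5
    accuracy = np.mean((y_prob > 0.5) == (y_true == 1))
    return loss, accuracy

loss, acc = evaluate_binary([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.1])
print("Loss =", loss, "Accuracy =", acc)
```

The third sample (true label 1, predicted 0.4) is both the one misclassified and the largest contributor to the loss.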

Other useful functions:

1. model.summary(): prints the details of every layer, with output like the following:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 64, 64, 3)         0         
_________________________________________________________________
zero_padding2d_2 (ZeroPaddin (None, 70, 70, 3)         0         
_________________________________________________________________
conv0 (Conv2D)               (None, 64, 64, 32)        4736      
_________________________________________________________________
bn0 (BatchNormalization)     (None, 64, 64, 32)        128       
_________________________________________________________________
activation_2 (Activation)    (None, 64, 64, 32)        0         
_________________________________________________________________
max_pool (MaxPooling2D)      (None, 32, 32, 32)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 32768)             0         
_________________________________________________________________
fc (Dense)                   (None, 1)                 32769     
=================================================================
Total params: 37,633
Trainable params: 37,569
Non-trainable params: 64
_________________________________________________________________
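The shapes and parameter counts in this summary can be checked by hand. Below is a small pure-Python sketch, assuming the standard formulas: a 'valid' convolution or pooling window of size f and stride s maps a size n to floor((n - f)/s) + 1; a Conv2D layer has f·f·C_in weights per filter plus one bias per filter; and BatchNormalization stores 4 values per channel (gamma and beta are trainable, the moving mean and variance are not). The helper names are illustrative.

```python
def out_size(n, f, s):
    """Output height/width of a 'valid' conv or pooling window of size f with stride s."""
    return (n - f) // s + 1

def conv_params(f, c_in, c_out):
    """Weights (f*f*c_in per filter) plus one bias per filter."""
    return f * f * c_in * c_out + c_out

n = 64 + 2 * 3                 # ZeroPadding2D((3, 3)): 64 -> 70
n = out_size(n, 7, 1)          # conv0, 7x7, stride 1: 70 -> 64
conv0 = conv_params(7, 3, 32)  # 4736
bn0 = 4 * 32                   # 128 (64 trainable + 64 non-trainable)
n = out_size(n, 2, 2)          # max_pool 2x2 (stride defaults to the pool size): 64 -> 32
flat = n * n * 32              # Flatten: 32768
fc = flat * 1 + 1              # Dense(1): 32769

total = conv0 + bn0 + fc
print(n, flat, total)          # 32 32768 37633
```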

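2. plot_model(): draws a diagram of the model's layout

```python
%matplotlib inline
from IPython.display import SVG
from keras.utils import plot_model
from keras.utils.vis_utils import model_to_dot
plot_model(model_test, to_file='happy_model.png')  # requires pydot and Graphviz
SVG(model_to_dot(model_test).create(prog='dot', format='svg'))
```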

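2. Residual networks (ResNet)

The deeper a neural network gets, the harder it is to train, because of problems such as vanishing gradients. Residual networks address this by adding shortcut (skip) connections, which make very deep networks much easier to train.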

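2.1. The identity block

Basic structure: the curved path on top of the figure is the shortcut. The input X goes through two convolutions, the result is added to X, and the sum is then passed through the activation. This lets gradients from deeper layers flow directly back to shallower layers.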
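The shortcut idea can be sketched in plain NumPy. To keep the example short, a tiny two-layer fully connected F(x) stands in for the convolutions (an assumption, not how the Keras block is built). The block outputs relu(F(x) + x); if F's weights are all zero, the block passes a non-negative x straight through. This is one intuition for why residual blocks are easy to stack: learning the identity only requires driving F toward zero.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0)

def residual_block(x, W1, b1, W2, b2):
    """Two-layer main path F(x), then add the shortcut x and apply ReLU."""
    h = relu(x @ W1 + b1)      # first layer + activation
    f = h @ W2 + b2            # second layer, no activation before the addition
    return relu(f + x)         # shortcut: add the input, then activate

rng = np.random.default_rng(0)
x = relu(rng.normal(size=(4, 8)))          # non-negative input, e.g. the output of a previous ReLU
zeros_W, zeros_b = np.zeros((8, 8)), np.zeros(8)

# With F's weights at zero, the block reduces to the identity mapping
y = residual_block(x, zeros_W, zeros_b, zeros_W, zeros_b)
print(np.allclose(y, x))  # True
```

Note that W2 must map back to the input width (8 here) for `f + x` to be defined, which is the same shape requirement discussed next.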

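Implementation details: because of the addition, the output of the convolutions on the main path must have the same shape as the input X, which is why this is called an identity block. The implementation below builds the version from the figure in which the shortcut skips over 3 layers; this is also an identity block.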

```python
import numpy as np
import tensorflow as tf
from keras import layers
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D
from keras.models import Model, load_model
from keras.preprocessing import image
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import preprocess_input
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from keras.initializers import glorot_uniform

import pydot
from IPython.display import SVG
import scipy.misc
from matplotlib.pyplot import imshow
import keras.backend as K
K.set_image_data_format('channels_last')  # tensors are (batch, height, width, channels): the channel axis is 3
K.set_learning_phase(1)                   # 1 = training phase (legacy Keras 2.x API)

import resnets_utils  # local helper file used with this exercise (not installable via pip)
```

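I have to say, Keras is seriously impressive.

Conv2D(filters, kernel_size, strides, padding, name, kernel_initializer)(X): performs the whole convolution in a single call.

BatchNormalization(axis, name): batch-normalizes along the channel axis; with channels_last data, axis = 3.

Activation(): applies the activation function.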
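What `axis=3` means can be sketched in NumPy: for channels_last tensors of shape (m, n_H, n_W, n_C), batch normalization computes one mean and one variance per channel, over the batch, height and width axes. This sketch covers only the training-time normalization step with gamma = 1 and beta = 0 (the epsilon of 1e-3 matches Keras's default); the real layer also learns gamma and beta and keeps moving averages for inference.

```python
import numpy as np

def batch_norm_channels_last(X, eps=1e-3):
    """Normalize each channel (axis 3) using statistics over axes 0, 1, 2."""
    mean = X.mean(axis=(0, 1, 2), keepdims=True)   # shape (1, 1, 1, n_C)
    var = X.var(axis=(0, 1, 2), keepdims=True)     # one variance per channel
    return (X - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(8, 4, 4, 3))
Y = batch_norm_channels_last(X)
# Per-channel mean is ~0 and standard deviation is ~1 after normalization
print(Y.mean(axis=(0, 1, 2)), Y.std(axis=(0, 1, 2)))
```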

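```python
def identify_block(X, f, filters, stage, block):
    """
    Implements the identity block.

    X - input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    f - integer, size (f x f) of the kernel of the middle CONV layer on the main path
    filters - list of integers, number of filters in each of the three CONV layers
    stage - integer, used to name the layers by their position in the network
    block - string, used to name the layers by their position in the network

    Returns:
    X - output of the identity block, tensor of shape (m, n_H, n_W, n_C)
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters  # number of output filters of each CONV layer
    X_shortcut = X        # keep the input for the shortcut path

    # Main path, component 1: 1x1 CONV -> BN -> RELU
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2a',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    # Main path, component 2: f x f CONV ('same' padding keeps height and width) -> BN -> RELU
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    # Main path, component 3: 1x1 CONV -> BN, no activation before the addition
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    # Add the shortcut to the main path (requires F3 == n_C_prev), then apply RELU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X
```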
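The identity block's shape constraint can be checked with a NumPy sketch. A 1x1 convolution with stride 1 is just a matrix multiplication over the channel axis, so here every main-path convolution is simplified to 1x1 (in the real block the middle layer is f x f with 'same' padding, which also keeps height and width) and BatchNormalization is left out. The addition only works when F3 equals the number of input channels; the function names are illustrative.

```python
import numpy as np

def conv1x1(X, W):
    """1x1 convolution, stride 1: (m, H, W, C_in) x (C_in, C_out) -> (m, H, W, C_out)."""
    return np.einsum('mhwc,cd->mhwd', X, W)

def identity_block_sketch(X, weights):
    """Simplified identity block: three 1x1 convs on the main path plus the shortcut."""
    W1, W2, W3 = weights
    X_shortcut = X
    Z = np.maximum(conv1x1(X, W1), 0)
    Z = np.maximum(conv1x1(Z, W2), 0)
    Z = conv1x1(Z, W3)                  # no activation before the addition
    if Z.shape != X_shortcut.shape:
        raise ValueError(f"cannot add {Z.shape} and {X_shortcut.shape}")
    return np.maximum(Z + X_shortcut, 0)

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 15, 15, 256))
ok = [rng.normal(size=s) for s in [(256, 64), (64, 64), (64, 256)]]   # F3 = 256 = n_C_prev
print(identity_block_sketch(X, ok).shape)  # (2, 15, 15, 256)

bad = [rng.normal(size=s) for s in [(256, 64), (64, 64), (64, 512)]]  # F3 = 512 != 256
try:
    identity_block_sketch(X, bad)
except ValueError as e:
    print("shape mismatch:", e)
```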

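2.2. The convolutional block

The identity block above requires the convolutions on the main path to keep the shape unchanged, so that the result can be added to X on the shortcut. When the shape does change, a convolution layer (followed by batch normalization) is added to the shortcut so that its output has the same shape as the main path.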

```python
def convolutional_block(X, f, filters, stage, block, s=2):
    """
    Convolutional block. X, f, filters, stage and block have the same meaning as in identify_block;
    s is the stride of the first main-path CONV layer and of the shortcut CONV layer.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    F1, F2, F3 = filters
    X_shortcut = X

    # Main path, component 1: 1x1 CONV with stride s (downsamples when s > 1) -> BN -> RELU
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    # Main path, component 2: f x f CONV with 'same' padding -> BN -> RELU
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    # Main path, component 3: 1x1 CONV -> BN, no activation before the addition
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    # Shortcut path: 1x1 CONV with stride s -> BN, so its output matches the main path's shape
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '1',
                        kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    # Add the two paths, then apply RELU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X
```
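Why the shortcut needs its own strided 1x1 convolution can be seen in a NumPy sketch. With 'valid' padding, a 1x1 convolution with stride s keeps every s-th pixel along height and width and then mixes channels, so it maps (m, n_H, n_W, C) to (m, ceil(n_H/s), ceil(n_W/s), F3), the same shape as the strided main path. Biases and BatchNormalization are omitted, the main path's later layers are collapsed into a single channel mix for brevity, and the shapes use stage 3's numbers from the summary below.

```python
import numpy as np

def strided_conv1x1(X, W, s):
    """1x1 'valid' convolution with stride s: subsample H and W, then mix channels."""
    X_sub = X[:, ::s, ::s, :]                       # keeps every s-th row and column
    return np.einsum('mhwc,cd->mhwd', X_sub, W)

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 15, 15, 256))               # e.g. the output of stage 2
F1, F3, s = 128, 512, 2

# Main path, first layer: stride-s 1x1 conv downsamples 15 -> 8
main = strided_conv1x1(X, rng.normal(size=(256, F1)), s)
# Simplification: the remaining main-path layers keep 8x8 and end with F3 channels
main = np.einsum('mhwc,cd->mhwd', main, rng.normal(size=(F1, F3)))

# Shortcut: stride-s 1x1 conv to F3 channels, so the two shapes match
shortcut = strided_conv1x1(X, rng.normal(size=(256, F3)), s)
print(main.shape, shortcut.shape)   # (2, 8, 8, 512) (2, 8, 8, 512)
out = np.maximum(main + shortcut, 0)
```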

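Building a 50-layer residual network

The network structure is as follows:

ID_BLOCK denotes an identity block and CONV_BLOCK a convolutional block. Each block contains 3 convolution layers; together with the first convolution layer and the final fully connected layer, that makes 50 layers in total.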
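The count of 50 can be checked with a little arithmetic, assuming the usual ResNet convention of counting only the convolution and fully connected layers on the main path (shortcut convolutions, pooling, BN and activation layers are not counted). The number of blocks per stage is taken from the ResNet50 code below.

```python
# Blocks per stage (one CONV_BLOCK followed by ID_BLOCKs), stages 2-5
blocks_per_stage = {2: 1 + 2, 3: 1 + 3, 4: 1 + 5, 5: 1 + 2}
layers_per_block = 3                       # 1x1 -> f x f -> 1x1 convolutions

block_layers = sum(blocks_per_stage.values()) * layers_per_block   # 16 blocks * 3 = 48
total = 1 + block_layers + 1               # + stage-1 conv1 + final Dense layer
print(total)  # 50
```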

```python
def ResNet50(input_shape=(64, 64, 3), classes=6):
    """
    CONV2D -> BATCHNORM -> RELU -> MAXPOOL -> CONVBLOCK -> IDBLOCK*2 -> CONVBLOCK -> IDBLOCK*3
    -> CONVBLOCK -> IDBLOCK*5 -> CONVBLOCK -> IDBLOCK*2 -> AVGPOOL -> TOPLAYER

    input_shape: shape of the images in the dataset
    classes: number of classes

    Returns: a Keras Model instance
    """
    # Define the input tensor (a placeholder)
    X_input = Input(input_shape)
    # Zero padding
    X = ZeroPadding2D((3, 3))(X_input)

    # Stage 1
    X = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), name='conv1', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))(X)

    # Stage 2 (s=1: height and width are kept, only the channel count changes)
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    X = identify_block(X, f=3, filters=[64, 64, 256], stage=2, block='b')
    X = identify_block(X, f=3, filters=[64, 64, 256], stage=2, block='c')

    # Stage 3
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block="a", s=2)
    X = identify_block(X, f=3, filters=[128, 128, 512], stage=3, block="b")
    X = identify_block(X, f=3, filters=[128, 128, 512], stage=3, block="c")
    X = identify_block(X, f=3, filters=[128, 128, 512], stage=3, block="d")

    # Stage 4
    X = convolutional_block(X, f=3, filters=[256, 256, 1024], stage=4, block="a", s=2)
    X = identify_block(X, f=3, filters=[256, 256, 1024], stage=4, block="b")
    X = identify_block(X, f=3, filters=[256, 256, 1024], stage=4, block="c")
    X = identify_block(X, f=3, filters=[256, 256, 1024], stage=4, block="d")
    X = identify_block(X, f=3, filters=[256, 256, 1024], stage=4, block="e")
    X = identify_block(X, f=3, filters=[256, 256, 1024], stage=4, block="f")

    # Stage 5
    X = convolutional_block(X, f=3, filters=[512, 512, 2048], stage=5, block="a", s=2)
    X = identify_block(X, f=3, filters=[512, 512, 2048], stage=5, block="b")
    X = identify_block(X, f=3, filters=[512, 512, 2048], stage=5, block="c")

    # Average pooling
    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)

    # Output layer
    X = Flatten()(X)
    X = Dense(classes, activation="softmax", name='fc' + str(classes), kernel_initializer=glorot_uniform(seed=0))(X)

    # Create the model
    model = Model(inputs=X_input, outputs=X, name='ResNet50')
    return model
```
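For the default 64×64×3 input, the spatial size after each stage can be traced with the 'valid' output-size formula floor((n - f)/s) + 1; the numbers agree with the shapes printed by model.summary() further down. This is a pure-Python sketch: each stage's downsampling comes from the stride-s 1x1 convolution in its convolutional block, and the 'same'-padded average pool is handled as ceil(n / s), which is how 'same' padding sizes its output in Keras/TensorFlow.

```python
import math

def valid(n, f, s):
    """Output size of a 'valid' window of size f and stride s."""
    return (n - f) // s + 1

sizes = {}
n = 64 + 2 * 3                                # ZeroPadding2D((3, 3)): 64 -> 70
n = valid(n, 7, 2); sizes['conv1'] = n        # 7x7 conv, stride 2
n = valid(n, 3, 2); sizes['maxpool'] = n      # 3x3 max pool, stride 2
for stage, s in [(2, 1), (3, 2), (4, 2), (5, 2)]:
    n = valid(n, 1, s); sizes[f'stage{stage}'] = n   # strided 1x1 conv in the conv block
n = math.ceil(n / 2); sizes['avgpool'] = n    # 2x2 average pool, 'same' padding
print(sizes)
# {'conv1': 32, 'maxpool': 15, 'stage2': 15, 'stage3': 8, 'stage4': 4, 'stage5': 2, 'avgpool': 1}
```

So Flatten receives a 1×1×2048 tensor, i.e. 2048 features, before the final Dense layer.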

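Create the model instance, then compile and train it. All we have to supply is the shape of the input data (and the number of classes).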

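```python
# X_train, Y_train are assumed to be loaded already: 64x64x3 images and one-hot labels for 6 classes
model = ResNet50(input_shape=(64, 64, 3), classes=6)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(X_train, Y_train, epochs=2, batch_size=32)
```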
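Because the loss is `categorical_crossentropy` on a 6-way softmax output, `Y_train` has to be one-hot encoded with shape (m, 6). If the labels start out as integer class indices, a NumPy conversion could look like the sketch below; the name `to_one_hot` is illustrative (`keras.utils.to_categorical` does the same job), and `sparse_categorical_crossentropy` is the alternative loss that accepts integer labels directly.

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """Turn integer labels of shape (m,) into one-hot rows of shape (m, num_classes)."""
    labels = np.asarray(labels).reshape(-1)
    return np.eye(num_classes)[labels]   # row i of the identity matrix encodes class i

Y = to_one_hot([0, 5, 2], num_classes=6)
print(Y)
# [[1. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 1.]
#  [0. 0. 1. 0. 0. 0.]]
```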

Evaluate the model

```python
preds = model.evaluate(X_test, Y_test)

print("Loss = " + str(preds[0]))
print("Accuracy = " + str(preds[1]))
```
Print a summary of the network:
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 70, 70, 3)    0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 32, 32, 64)   9472        zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 32, 32, 64)   256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_66 (Activation)      (None, 32, 32, 64)   0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 15, 15, 64)   0           activation_66[0][0]              
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 15, 15, 64)   4160        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 15, 15, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_67 (Activation)      (None, 15, 15, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 15, 15, 64)   36928       activation_67[0][0]              
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 15, 15, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_68 (Activation)      (None, 15, 15, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 15, 15, 256)  16640       activation_68[0][0]              
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 15, 15, 256)  16640       max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 15, 15, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 15, 15, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_22 (Add)                    (None, 15, 15, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_69 (Activation)      (None, 15, 15, 256)  0           add_22[0][0]                     
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 15, 15, 64)   16448       activation_69[0][0]              
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 15, 15, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_70 (Activation)      (None, 15, 15, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 15, 15, 64)   36928       activation_70[0][0]              
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 15, 15, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_71 (Activation)      (None, 15, 15, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 15, 15, 256)  16640       activation_71[0][0]              
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 15, 15, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_23 (Add)                    (None, 15, 15, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_69[0][0]              
__________________________________________________________________________________________________
activation_72 (Activation)      (None, 15, 15, 256)  0           add_23[0][0]                     
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 15, 15, 64)   16448       activation_72[0][0]              
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 15, 15, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_73 (Activation)      (None, 15, 15, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 15, 15, 64)   36928       activation_73[0][0]              
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 15, 15, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_74 (Activation)      (None, 15, 15, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 15, 15, 256)  16640       activation_74[0][0]              
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 15, 15, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_24 (Add)                    (None, 15, 15, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_72[0][0]              
__________________________________________________________________________________________________
activation_75 (Activation)      (None, 15, 15, 256)  0           add_24[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 8, 8, 128)    32896       activation_75[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 8, 8, 128)    512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_76 (Activation)      (None, 8, 8, 128)    0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 8, 8, 128)    147584      activation_76[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 8, 8, 128)    512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 8, 8, 128)    0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 8, 8, 512)    66048       activation_77[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 8, 8, 512)    131584      activation_75[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 8, 8, 512)    2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 8, 8, 512)    2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_25 (Add)                    (None, 8, 8, 512)    0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_78 (Activation)      (None, 8, 8, 512)    0           add_25[0][0]                     
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 8, 8, 128)    65664       activation_78[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 8, 8, 128)    512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_79 (Activation)      (None, 8, 8, 128)    0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 8, 8, 128)    147584      activation_79[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 8, 8, 128)    512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_80 (Activation)      (None, 8, 8, 128)    0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 8, 8, 512)    66048       activation_80[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 8, 8, 512)    2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_26 (Add)                    (None, 8, 8, 512)    0           bn3b_branch2c[0][0]              
                                                                 activation_78[0][0]              
__________________________________________________________________________________________________
activation_81 (Activation)      (None, 8, 8, 512)    0           add_26[0][0]                     
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 8, 8, 128)    65664       activation_81[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 8, 8, 128)    512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_82 (Activation)      (None, 8, 8, 128)    0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 8, 8, 128)    147584      activation_82[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 8, 8, 128)    512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_83 (Activation)      (None, 8, 8, 128)    0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 8, 8, 512)    66048       activation_83[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 8, 8, 512)    2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_27 (Add)                    (None, 8, 8, 512)    0           bn3c_branch2c[0][0]              
                                                                 activation_81[0][0]              
__________________________________________________________________________________________________
activation_84 (Activation)      (None, 8, 8, 512)    0           add_27[0][0]                     
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 8, 8, 128)    65664       activation_84[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 8, 8, 128)    512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_85 (Activation)      (None, 8, 8, 128)    0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 8, 8, 128)    147584      activation_85[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 8, 8, 128)    512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_86 (Activation)      (None, 8, 8, 128)    0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 8, 8, 512)    66048       activation_86[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 8, 8, 512)    2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_28 (Add)                    (None, 8, 8, 512)    0           bn3d_branch2c[0][0]              
                                                                 activation_84[0][0]              
__________________________________________________________________________________________________
activation_87 (Activation)      (None, 8, 8, 512)    0           add_28[0][0]                     
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 4, 4, 256)    131328      activation_87[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_88 (Activation)      (None, 4, 4, 256)    0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_88[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_89 (Activation)      (None, 4, 4, 256)    0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_89[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 4, 4, 1024)   525312      activation_87[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 4, 4, 1024)   4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_29 (Add)                    (None, 4, 4, 1024)   0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_90 (Activation)      (None, 4, 4, 1024)   0           add_29[0][0]                     
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 4, 4, 256)    262400      activation_90[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_91 (Activation)      (None, 4, 4, 256)    0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_91[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_92 (Activation)      (None, 4, 4, 256)    0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_92[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_30 (Add)                    (None, 4, 4, 1024)   0           bn4b_branch2c[0][0]              
                                                                 activation_90[0][0]              
__________________________________________________________________________________________________
activation_93 (Activation)      (None, 4, 4, 1024)   0           add_30[0][0]                     
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 4, 4, 256)    262400      activation_93[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_94 (Activation)      (None, 4, 4, 256)    0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_94[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_95 (Activation)      (None, 4, 4, 256)    0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_95[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_31 (Add)                    (None, 4, 4, 1024)   0           bn4c_branch2c[0][0]              
                                                                 activation_93[0][0]              
__________________________________________________________________________________________________
activation_96 (Activation)      (None, 4, 4, 1024)   0           add_31[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 4, 4, 256)    262400      activation_96[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_97 (Activation)      (None, 4, 4, 256)    0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_97[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_98 (Activation)      (None, 4, 4, 256)    0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_98[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_32 (Add)                    (None, 4, 4, 1024)   0           bn4d_branch2c[0][0]              
                                                                 activation_96[0][0]              
__________________________________________________________________________________________________
activation_99 (Activation)      (None, 4, 4, 1024)   0           add_32[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 4, 4, 256)    262400      activation_99[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_100 (Activation)     (None, 4, 4, 256)    0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_100[0][0]             
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_101 (Activation)     (None, 4, 4, 256)    0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_101[0][0]             
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_33 (Add)                    (None, 4, 4, 1024)   0           bn4e_branch2c[0][0]              
                                                                 activation_99[0][0]              
__________________________________________________________________________________________________
activation_102 (Activation)     (None, 4, 4, 1024)   0           add_33[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 4, 4, 256)    262400      activation_102[0][0]             
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 4, 4, 256)    1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_103 (Activation)     (None, 4, 4, 256)    0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 4, 4, 256)    590080      activation_103[0][0]             
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 4, 4, 256)    1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_104 (Activation)     (None, 4, 4, 256)    0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 4, 4, 1024)   263168      activation_104[0][0]             
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 4, 4, 1024)   4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_34 (Add)                    (None, 4, 4, 1024)   0           bn4f_branch2c[0][0]              
                                                                 activation_102[0][0]             
__________________________________________________________________________________________________
activation_105 (Activation)     (None, 4, 4, 1024)   0           add_34[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 2, 2, 512)    524800      activation_105[0][0]             
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 2, 2, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_106 (Activation)     (None, 2, 2, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 2, 2, 512)    2359808     activation_106[0][0]             
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 2, 2, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_107 (Activation)     (None, 2, 2, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 2, 2, 2048)   1050624     activation_107[0][0]             
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 2, 2, 2048)   2099200     activation_105[0][0]             
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 2, 2, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 2, 2, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_35 (Add)                    (None, 2, 2, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_108 (Activation)     (None, 2, 2, 2048)   0           add_35[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 2, 2, 512)    1049088     activation_108[0][0]             
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 2, 2, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_109 (Activation)     (None, 2, 2, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 2, 2, 512)    2359808     activation_109[0][0]             
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 2, 2, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_110 (Activation)     (None, 2, 2, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 2, 2, 2048)   1050624     activation_110[0][0]             
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 2, 2, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_36 (Add)                    (None, 2, 2, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_108[0][0]             
__________________________________________________________________________________________________
activation_111 (Activation)     (None, 2, 2, 2048)   0           add_36[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 2, 2, 512)    1049088     activation_111[0][0]             
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 2, 2, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_112 (Activation)     (None, 2, 2, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 2, 2, 512)    2359808     activation_112[0][0]             
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 2, 2, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_113 (Activation)     (None, 2, 2, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 2, 2, 2048)   1050624     activation_113[0][0]             
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 2, 2, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_37 (Add)                    (None, 2, 2, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_111[0][0]             
__________________________________________________________________________________________________
activation_114 (Activation)     (None, 2, 2, 2048)   0           add_37[0][0]                     
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 1, 1, 2048)   0           activation_114[0][0]             
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 2048)         0           average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
fc6 (Dense)                     (None, 6)            12294       flatten_2[0][0]                  
==================================================================================================
Total params: 23,600,006
Trainable params: 23,546,886
Non-trainable params: 53,120
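The parameter counts in the summary can be reproduced by hand. A Conv2D layer with a k×k kernel has k·k·C_in·C_out weights plus C_out biases (for example res4c_branch2b: 3·3·256·256 + 256 = 590,080). Each BatchNormalization layer has 4 parameters per channel. Of these, the moving mean and moving variance are not trainable, and that is where the 53,120 non-trainable parameters come from. The sketch below is a minimal check under stated assumptions: the stage layout is taken from the layer names (res2a to res2c, res3a to res3d, res4a to res4f, res5a to res5c), the input has 3 channels, and there are 6 output classes. The helper names are hypothetical and are not part of Keras.

```python
# Hypothetical helpers that recompute the parameter counts printed by model.summary().

def conv_params(k, c_in, c_out):
    """Conv2D with a k x k kernel plus one bias per output channel (Keras default use_bias=True)."""
    return k * k * c_in * c_out + c_out

def bn_params(c):
    """BatchNormalization: gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * c

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

def bottleneck(c_in, filters, shortcut_conv):
    """Conv block (shortcut_conv=True) or identity block: 1x1 -> 3x3 -> 1x1 convolutions."""
    f1, f2, f3 = filters
    conv = conv_params(1, c_in, f1) + conv_params(3, f1, f2) + conv_params(1, f2, f3)
    bn_channels = f1 + f2 + f3
    if shortcut_conv:
        conv += conv_params(1, c_in, f3)  # the resXa_branch1 shortcut convolution
        bn_channels += f3                 # and its bnXa_branch1 layer
    return conv, bn_channels

# (filters, number of identity blocks after the conv block) for stages 2-5,
# assumed from the layer names res2a ... res5c
STAGES = [((64, 64, 256), 2), ((128, 128, 512), 3),
          ((256, 256, 1024), 5), ((512, 512, 2048), 2)]

def resnet50_params(in_channels=3, n_classes=6):
    conv_total = conv_params(7, in_channels, 64)  # conv1 (7x7, 64 filters)
    bn_channels = 64                              # bn_conv1
    c_in = 64
    for filters, n_identity in STAGES:
        c, b = bottleneck(c_in, filters, shortcut_conv=True)
        conv_total += c
        bn_channels += b
        c_in = filters[2]
        for _ in range(n_identity):
            c, b = bottleneck(c_in, filters, shortcut_conv=False)
            conv_total += c
            bn_channels += b
    # Pooling, padding, activation, add and flatten layers have no parameters
    total = conv_total + bn_params(bn_channels) + dense_params(c_in, n_classes)
    non_trainable = 2 * bn_channels  # moving mean + moving variance per channel
    return total, total - non_trainable, non_trainable

print(resnet50_params())  # (23600006, 23546886, 53120)
```

The totals match the three lines at the end of the summary, which is a quick way to confirm that the network was assembled as intended.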

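```python
# Both functions need the pydot and graphviz packages to be installed
from keras.utils.vis_utils import plot_model, model_to_dot
from IPython.display import SVG

plot_model(model, to_file='model.png')  # save the architecture diagram as a PNG
SVG(model_to_dot(model).create(prog='dot', format='svg'))  # render it inline in Jupyter
```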

Testing with your own data

Here is the basic workflow:

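```python
import numpy as np
import matplotlib.pyplot as plt  # plt is used to display images
from PIL import Image
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input

# Jupyter-only magic: show plots inline (remove when running as a plain script)
%matplotlib inline

img_path = 'images/fingers_big/2.jpg'

# Load the image resized to the 64x64 input size the model was built for
my_image = image.load_img(img_path, target_size=(64, 64))
my_image = image.img_to_array(my_image)

# Add a batch dimension: (64, 64, 3) -> (1, 64, 64, 3)
my_image = np.expand_dims(my_image, axis=0)
# The preprocessing here should match what was applied to the training data;
# if the training images were only divided by 255, use my_image / 255. instead
my_image = preprocess_input(my_image)

print("my_image.shape = " + str(my_image.shape))

print("class prediction vector [p(0), p(1), p(2), p(3), p(4), p(5)] = ")
print(model.predict(my_image))

# Display the original image (scipy.misc.imread was removed in SciPy 1.2)
plt.imshow(Image.open(img_path))
```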
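Two details in the prediction step are easy to get wrong. First, as far as I know, `preprocess_input` from `keras.applications.imagenet_utils` defaults to 'caffe' mode in Keras 2.x: it flips RGB to BGR and subtracts the ImageNet channel means without rescaling. That produces values in a very different range from plain 1/255 scaling, so a test image should go through the same preprocessing as the training data. Second, `model.predict` returns one softmax probability per class, and the predicted class (0 to 5) is the index of the largest entry. The numpy-only sketch below illustrates both points. The function names are hypothetical, the mean values are assumed to match the Keras 'caffe' mode, and the probability vector is made up for illustration.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, assumed to match the 'caffe' mode
# of keras.applications.imagenet_utils.preprocess_input
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68])


def caffe_style(x):
    """Sketch of 'caffe' preprocessing: RGB -> BGR, then subtract the channel means (no scaling)."""
    x = np.asarray(x, dtype=np.float64)[..., ::-1]  # reverse the channel axis
    return x - IMAGENET_BGR_MEAN


def scale_01(x):
    """Plain 1/255 scaling to the [0, 1] range."""
    return np.asarray(x, dtype=np.float64) / 255.0


def decode_prediction(probs):
    """Return (class index, probability) for a single row of model.predict output."""
    probs = np.asarray(probs, dtype=np.float64).ravel()
    idx = int(np.argmax(probs))
    return idx, float(probs[idx])


# A batch holding one white 1x1 RGB image, shape (1, 1, 1, 3)
white = np.full((1, 1, 1, 3), 255, dtype=np.uint8)
print(caffe_style(white).ravel())  # roughly 151.06, 138.22, 131.32: centered, not scaled
print(scale_01(white).ravel())     # all ones

# Made-up softmax output, for illustration only
print(decode_prediction([[0.01, 0.02, 0.90, 0.03, 0.02, 0.02]]))  # (2, 0.9)
```

If the network was trained on inputs in [0, 1], feeding it mean-subtracted values in the hundreds will usually give unreliable predictions, so checking this is a good first step when results on your own photos look wrong.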