Building a model with the functional API lets the model take multiple inputs and produce multiple outputs.
1. Check the TensorFlow version
```python
import tensorflow as tf
print('TensorFlow Version: {}'.format(tf.__version__))
print(tf.config.list_physical_devices())  # CPUs/GPUs visible to TensorFlow
```
2. Classification model for the fashion_mnist dataset
2.1 Building the model with Sequential
```python
from keras import Sequential
from keras.layers import Flatten, Dense, Dropout
from keras import Input

# Layers are stacked in order; each one feeds the next
model = Sequential()
model.add(Input(shape=(28, 28)))  # one 28x28 grayscale image per sample
model.add(Flatten())              # 28x28 -> 784-dim vector
model.add(Dense(units=256, kernel_initializer='normal', activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64, kernel_initializer='normal', activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10, kernel_initializer='normal', activation='softmax'))  # one probability per class
model.summary()
```
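`model.summary()` reports a parameter count for each layer. A Dense layer holds (inputs × units) weights plus one bias per unit. Flatten and Dropout have no parameters. The short calculation below, using a hypothetical helper `dense_params`, reproduces the counts that the summary of this model should show:

```python
def dense_params(n_in, units):
    """Hypothetical helper: n_in * units weights plus one bias per unit."""
    return n_in * units + units

flat = 28 * 28  # Flatten turns each 28x28 image into 784 values (no parameters)
layers = [dense_params(flat, 256), dense_params(256, 64), dense_params(64, 10)]
total = sum(layers)  # Dropout layers add nothing
print(layers, total)  # [200960, 16448, 650] 218058
```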
2.2 Building the model with the functional API
```python
from keras.layers import Flatten, Dense, Dropout
from keras import Input, Model

# Each layer is called on the previous tensor; Model() then ties inputs to outputs
inputs = Input(shape=(28, 28))  # named `inputs` so the built-in input() is not shadowed
x = Flatten()(inputs)
x = Dense(units=256, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
outputs = Dense(units=10, kernel_initializer='normal', activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()
```
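The two summaries show that both models have the same architecture. They are also compiled and trained in exactly the same way.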
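The following is a minimal sketch of compiling and training the functional model. It assumes the Adam optimizer, one epoch and a batch size of 16; these are illustrative choices, not taken from the original. The block trains on random placeholder arrays with fashion_mnist's shapes so that it runs offline. For the real data, use `(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()` and divide the pixels by 255. The labels are integers 0-9, which is why the loss is `sparse_categorical_crossentropy`.

```python
import numpy as np
from keras import Input, Model
from keras.layers import Flatten, Dense, Dropout

# Same functional model as above, rebuilt so this block runs on its own
inputs = Input(shape=(28, 28))
x = Flatten()(inputs)
x = Dense(units=256, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
outputs = Dense(units=10, kernel_initializer='normal', activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)

# Integer class labels -> sparse categorical cross-entropy (assumed optimizer: adam)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Placeholder data with fashion_mnist's shapes: images in [0, 1), labels 0-9
rng = np.random.default_rng(0)
x_train = rng.random((64, 28, 28), dtype=np.float32)
y_train = rng.integers(0, 10, size=(64,))

history = model.fit(x_train, y_train, batch_size=16, epochs=1, verbose=0)
probs = model.predict(x_train[:5], verbose=0)  # shape (5, 10); each row sums to 1
```

The Sequential model from 2.1 can be compiled and fitted with exactly the same two calls.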
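3. Building multi-input, multi-output models with the functional API
This example has two inputs and one output. It compares two images and predicts whether they are the same.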
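Training such a model requires image pairs with a 0/1 target. Below is a minimal NumPy sketch. It assumes that "the same" means "the same class label". Both this reading and the helper `make_pairs` are hypothetical, not the author's method.

```python
import numpy as np

def make_pairs(images, labels, n_pairs, seed=0):
    """Hypothetical helper: draw random image pairs; the target is 1.0 when
    both images share a class label (assumed meaning of "same"), else 0.0."""
    rng = np.random.default_rng(seed)
    i = rng.integers(0, len(images), size=n_pairs)
    j = rng.integers(0, len(images), size=n_pairs)
    targets = (labels[i] == labels[j]).astype(np.float32)
    return images[i], images[j], targets

# Usage on tiny stand-in data: 4 distinct images, 2 classes
images = np.arange(4 * 28 * 28, dtype=np.float32).reshape(4, 28, 28)
labels = np.array([0, 1, 0, 1])
left, right, targets = make_pairs(images, labels, n_pairs=10)
```

With the 10 classes of fashion_mnist, only about one in ten random pairs is positive. You may therefore want to sample positive pairs deliberately to balance the targets.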
```python
from keras.layers import Flatten, Dense, Dropout
from keras import Input, Model
import keras

input1 = Input(shape=(28, 28))  # first image
input2 = Input(shape=(28, 28))  # second image
x1 = Flatten()(input1)
x2 = Flatten()(input2)
x = keras.layers.concatenate([x1, x2])  # 784 + 784 -> 1568 features
x = Dense(units=256, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
output = Dense(units=1, kernel_initializer='normal', activation='sigmoid')(x)  # probability that the images match
model = Model(inputs=[input1, input2], outputs=output)  # two inputs, one output
model.summary()
```
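Below is a minimal training sketch for the two-input model. It assumes the Adam optimizer, one epoch and random placeholder pairs with random 0/1 targets, so that it runs offline. The single sigmoid output pairs with `binary_crossentropy`. The two input arrays are passed as a list, in the same order as `Model(inputs=[input1, input2])`.

```python
import numpy as np
import keras
from keras import Input, Model
from keras.layers import Flatten, Dense, Dropout

# Same two-input model as above, rebuilt so this block runs on its own
input1 = Input(shape=(28, 28))
input2 = Input(shape=(28, 28))
x = keras.layers.concatenate([Flatten()(input1), Flatten()(input2)])
x = Dense(units=256, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64, kernel_initializer='normal', activation='relu')(x)
x = Dropout(rate=0.1)(x)
output = Dense(units=1, kernel_initializer='normal', activation='sigmoid')(x)
model = Model(inputs=[input1, input2], outputs=output)

# One sigmoid output -> binary cross-entropy on 0/1 targets (assumed optimizer: adam)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Placeholder data: 64 random image pairs with random 0/1 "same" targets
rng = np.random.default_rng(0)
left = rng.random((64, 28, 28), dtype=np.float32)
right = rng.random((64, 28, 28), dtype=np.float32)
same = rng.integers(0, 2, size=(64, 1)).astype(np.float32)

# Inputs go in as a list, ordered like Model(inputs=[input1, input2])
model.fit([left, right], same, batch_size=16, epochs=1, verbose=0)
scores = model.predict([left[:4], right[:4]], verbose=0)  # shape (4, 1), values in [0, 1]
```

The concatenated vector is 1568 wide, so the first Dense layer alone holds 1568 × 256 + 256 = 401,664 parameters. For a model that really has several outputs, pass `outputs=[out1, out2]` to `Model`. Then give `compile` one loss per output, for example as a list in the same order, and pass a matching list of target arrays to `fit`.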