神经网络基础-神经网络补充概念-30-搭建神经网络块

概念

搭建神经网络块是一种常见的做法,它可以帮助你更好地组织和复用网络结构。神经网络块可以是一些相对独立的模块,例如卷积块、全连接块等,用于构建更复杂的网络架构。

代码实现

python 复制代码
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 定义一个卷积块
def convolutional_block(x, num_filters, kernel_size, pool_size):
    x = layers.Conv2D(num_filters, kernel_size, activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(pool_size)(x)
    return x

# 构建神经网络模型
def build_model():
    inputs = layers.Input(shape=(28, 28, 1))  # 输入数据为28x28的灰度图像
    x = convolutional_block(inputs, num_filters=32, kernel_size=(3, 3), pool_size=(2, 2))
    x = convolutional_block(x, num_filters=64, kernel_size=(3, 3), pool_size=(2, 2))
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(10, activation='softmax')(x)  # 输出层,10个类别
    model = keras.Model(inputs, outputs)
    return model

# 加载数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype('float32') / 255.0
x_test = np.expand_dims(x_test, axis=-1).astype('float32') / 255.0
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)

# 构建模型
model = build_model()

# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.1)

# 评估模型
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_accuracy)
相关推荐
寻丶幽风1 小时前
论文阅读笔记——双流网络
论文阅读·笔记·深度学习·视频理解·双流网络
CM莫问3 小时前
<论文>(微软)避免推荐域外物品:基于LLM的受限生成式推荐
人工智能·算法·大模型·推荐算法·受限生成
康谋自动驾驶4 小时前
康谋分享 | 自动驾驶仿真进入“标准时代”:aiSim全面对接ASAM OpenX
人工智能·科技·算法·机器学习·自动驾驶·汽车
深蓝学院5 小时前
密西根大学新作——LightEMMA:自动驾驶中轻量级端到端多模态模型
人工智能·机器学习·自动驾驶
归去_来兮6 小时前
人工神经网络(ANN)模型
人工智能·机器学习·人工神经网络
2201_754918416 小时前
深入理解卷积神经网络:从基础原理到实战应用
人工智能·神经网络·cnn
强盛小灵通专卖员6 小时前
DL00219-基于深度学习的水稻病害检测系统含源码
人工智能·深度学习·水稻病害
Luke Ewin6 小时前
CentOS7.9部署FunASR实时语音识别接口 | 部署商用级别实时语音识别接口FunASR
人工智能·语音识别·实时语音识别·商用级别实时语音识别
白熊1886 小时前
【计算机视觉】OpenCV实战项目:Face-Mask-Detection 项目深度解析:基于深度学习的口罩检测系统
深度学习·opencv·计算机视觉
Joern-Lee6 小时前
初探机器学习与深度学习
人工智能·深度学习·机器学习