JAX 来构建一个基本的人工神经网络(ANN)进行分类任务

python 复制代码
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.experimental import optimizers
from jax.nn import relu, softmax

# 构建神经网络模型
def neural_network(params, x):
    for W, b in params:
        x = jnp.dot(x, W) + b
        x = relu(x)
    return softmax(x)

# 初始化参数
def init_params(rng, layer_sizes):
    keys = random.split(rng, len(layer_sizes))
    return [(random.normal(k, (m, n)), random.normal(k, (n,))) 
            for k, (m, n) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:]))]

# 定义损失函数
def cross_entropy_loss(params, batch):
    inputs, targets = batch
    preds = neural_network(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

# 初始化优化器
def init_optimizer(params):
    return optimizers.adam(init_params)

# 更新参数
@jit
def update(params, batch, opt_state):
    grads = grad(cross_entropy_loss)(params, batch)
    updates, opt_state = opt.update(grads, opt_state)
    return opt_params, opt_state

# 训练函数
def train(rng, params, data, num_epochs=10, batch_size=32):
    opt_init, opt_update, get_params = init_optimizer(params)
    opt_state = opt_init(params)
    
    num_batches = len(data) // batch_size
    
    for epoch in range(num_epochs):
        rng, subrng = random.split(rng)
        for batch_idx in range(num_batches):
            batch = get_batch(data, batch_idx, batch_size)
            params = update(params, batch, opt_state)
        
        train_loss = cross_entropy_loss(params, batch)
        print(f"Epoch {epoch+1}, Loss: {train_loss}")
    
    return get_params(opt_state)

# 评估函数
def evaluate(params, data):
    inputs, targets = data
    preds = neural_network(params, inputs)
    accuracy = jnp.mean(jnp.argmax(preds, axis=1) == jnp.argmax(targets, axis=1))
    return accuracy

# 示例数据集和参数
rng = random.PRNGKey(0)
input_size = 784
num_classes = 10
layer_sizes = [input_size, 128, num_classes]
params = init_params(rng, layer_sizes)
opt = init_optimizer(params)

# 使用数据集进行训练
trained_params = train(rng, params, data)

# 评估模型
accuracy = evaluate(trained_params, test_data)
print("Test Accuracy:", accuracy)

理解如何使用 JAX 或其他深度学习库构建人工智能(AI)系统需要一定的学习和实践。下面我给你一个简单的例子来说明如何使用 JAX 来构建一个基本的人工神经网络(ANN)进行分类任务。

首先,让我们假设你想解决一个简单的图像分类问题,例如手写数字识别。我们将使用一个基本的全连接神经网络来实现这个任务。

这只是一个简单的示例,用于说明如何使用 JAX 来构建神经网络进行图像分类任务。实际情况下,你可能需要更复杂的网络结构、更大规模的数据集以及更多的训练技巧来实现更好的性能。继续学习和实践将帮助你更好地理解如何构建 AI 系统。

要生成并存储模型文件,你可以使用 joblib 库,就像之前保存模型一样。以下是评估模型并保存模型的代码示例:

python 复制代码
python
import joblib

# 评估模型
accuracy = evaluate(trained_params, test_data)
print("Test Accuracy:", accuracy)

# 将训练好的模型保存为文件
joblib.dump(trained_params, 'trained_model.pkl')

此代码评估了训练好的模型在测试数据集上的准确率,并将模型保存为名为 trained_model.pkl 的文件。在此之后,你可以将 trained_model.pkl 文件用于部署模型或在其他地方进行预测。

让我们假设你已经训练了一个模型来识别手写数字。现在,我将展示如何结合手写图片应用并输出识别结果。我们将使用 Python 的 Flask 框架来构建一个简单的 Web 应用,并在用户上传手写数字图片后,使用训练好的模型进行预测。

首先,确保你已经安装了 Flask:

bash

bash 复制代码
pip install flask

然后,你可以创建一个名为 app.py 的 Python 脚本,其中包含以下内容:

python 复制代码
python
from flask import Flask, render_template, request
from PIL import Image
import numpy as np
import joblib

app = Flask(__name__)

# 加载训练好的模型
model = joblib.load('trained_model.pkl')

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    # 获取上传的图片文件
    file = request.files['file']
    
    # 将上传的图片转换为灰度图像并缩放为 28x28 像素
    img = Image.open(file).convert('L').resize((28, 28))
    
    # 将图像数据转换为 numpy 数组
    img_array = np.array(img) / 255.0  # 将像素值缩放到 [0, 1] 范围内
    
    # 将图像数据扁平化成一维数组
    img_flat = img_array.flatten()
    
    # 使用模型进行预测
    prediction = model.predict([img_flat])[0]
    
    return render_template('predict.html', prediction=prediction)

if __name__ == '__main__':
    app.run(debug=True)

上述代码创建了一个基本的 Flask 应用,包括两个路由:

  • / 路由用于渲染主页,其中包含一个表单,允许用户上传手写数字图片。

  • /predict 路由用于接收上传的图片并使用模型进行预测。

接下来,你需要创建两个 HTML 模板文件 index.html 和 predict.html,并放置在名为 templates 的文件夹中。index.html 用于渲染主页,而 predict.html 用于显示预测结果。

index.html 内容如下:

html 复制代码
html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Handwritten Digit Recognition</title>
</head>
<body>
    <h1>Handwritten Digit Recognition</h1>
    <form action="/predict" method="post" enctype="multipart/form-data">
        <input type="file" name="file" accept="image/*">
        <button type="submit">Predict</button>
    </form>
</body>
</html>

现在,你可以运行应用:

bash

bash 复制代码
python app.py

然后在浏览器中访问 http://localhost:5000/,上传手写数字图片并查看预测结果。

相关推荐
小王子10242 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥3 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin3 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客4 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
RZer5 小时前
Hypium+python鸿蒙原生自动化安装配置
python·自动化·harmonyos
CM莫问6 小时前
什么是门控循环单元?
人工智能·pytorch·python·rnn·深度学习·算法·gru
查理零世7 小时前
【算法】回溯算法专题① ——子集型回溯 python
python·算法
圆圆滚滚小企鹅。8 小时前
刷题记录 HOT100回溯算法-6:79. 单词搜索
笔记·python·算法·leetcode
程序猿阿伟8 小时前
《解锁AI黑科技:数据分类聚类与可视化》
人工智能·科技·分类