深度学习-最简单的Demo-直接运行

根据动手学深度学习第一个最简单的Demo,通过此demo旨在了解深度学习都干了什么事情,为什么要做这些事情,便于后续理解更加复杂的神经网络训练

python 复制代码
import torch
import random

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        random_index = indices[i:min(i + batch_size, num_examples)]
        batch_indices = torch.tensor(random_index)
        yield features[batch_indices], lables[batch_indices]

def linreg(X, w, b):
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

true_w = torch.tensor([2, -3.4])
true_b = 4.3
batch_size = 10
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
features, lables = synthetic_data(true_w, true_b, 1000)
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, lables):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w, b], lr, batch_size)
    
    with torch.no_grad():
        train_l = loss(net(features, w, b), lables)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}', f'w={w}, b={b}')

结果:

python 复制代码
epoch 1, loss 0.042685 w=tensor([[ 1.8874],
        [-3.2286]], requires_grad=True), b=tensor([4.0865], requires_grad=True)
epoch 2, loss 0.000169 w=tensor([[ 1.9937],
        [-3.3907]], requires_grad=True), b=tensor([4.2893], requires_grad=True)
epoch 3, loss 0.000054 w=tensor([[ 2.0000],
        [-3.3992]], requires_grad=True), b=tensor([4.2994], requires_grad=True)

能看到,最者不断的训练,模型的参数,逐渐靠近我们模拟数据集的原始参数。

相关推荐
DeepFlow 零侵扰全栈可观测5 分钟前
DeepFlow 全栈可观测性 护航某银行核心系统全生命周期
数据库·人工智能·分布式·云原生·金融
哈哈哈也不行吗13 分钟前
从入门到精通:大角几何在教学中的深度应用
人工智能·几何画板·几何绘图·大角几何·数学绘图工具
也不知秋13 分钟前
巧用 AI 提升 Excel 工作效率
人工智能
_codemonster21 分钟前
深度学习实战(基于pytroch)系列(四十三)深度循环神经网络pytorch实现
pytorch·rnn·深度学习
szxinmai主板定制专家24 分钟前
基于x86和ARM的EtherCAT运动控制器,最大支持32轴,支持codesys和实时系统优化
arm开发·人工智能·嵌入式硬件·yolo
JarryStudy28 分钟前
自动调优在Triton-on-Ascend中的应用:从参数优化到性能极致挖掘
人工智能·算法·昇腾·cann·ascend c
TTGGGFF32 分钟前
AI 十大论文精讲(二):GPT-3 论文全景解析——大模型 + 提示词如何解锁 “举一反三” 能力?
人工智能·gpt-3
智能化咨询34 分钟前
(66页PPT)高校智慧校园解决方案(附下载方式)
大数据·数据库·人工智能
腾飞开源35 分钟前
08_Spring AI 干货笔记之结构化输出
人工智能·spring ai·数据类型转换·结构化输出·ai模型集成·输出转换器·json模式
轮到我狗叫了36 分钟前
Contrastive pseudo learning for openworld deepfake attribution 超细致论文笔记,第一次读论文
人工智能