深度学习3.2 线性回归的从零开始实现

3.2.1 生成数据集

python 复制代码
%matplotlib inline
import random
import torch
from d2l import torch as d2l

def synthetic_data(w, b, num_examples):
    # 生成特征矩阵X,形状为(num_examples, len(w)),符合标准正态分布
    X = torch.normal(0, 1, (num_examples, len(w)))
    # 计算标签y = Xw + b
    y = torch.matmul(X, w) + b
    # 添加均值为0、标准差为0.01的噪声
    y += torch.normal(0, 0.01, y.shape)
    # 将y转换为列向量(形状:num_examples × 1)
    return X, y.reshape((-1, 1))
python 复制代码
true_w = torch.tensor([2, -3.4])  # 定义真实权重
true_b = 4.2                      # 定义真实偏置
features, labels = synthetic_data(true_w, true_b, 1000)  # 生成1000个样本

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)

features[:, 1]: 选取所有样本的第二个特征(索引为1的列)。

3.2.1 读取数据集

python 复制代码
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))

    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

tensor([[ 1.6556, 0.1851],

-1.4880, 0.0684\], \[ 1.0536, 0.9818\], \[-0.7794, -1.9199\], \[-0.3383, 0.2244\], \[-0.2260, 3.1530\], \[-2.3626, 1.1877\], \[-0.3301, 0.1781\], \[-0.6136, -1.2974\], \[-0.3397, -0.2088\]\]) tensor(\[\[ 6.8888\], \[ 0.9887\], \[ 2.9757\], \[ 9.1748\], \[ 2.7541\], \[-6.9671\], \[-4.5522\], \[ 2.9436\], \[ 7.3728\], \[ 4.2270\]\])

相关推荐
珠海西格电力15 小时前
零碳园区数字感知基础架构规划:IoT 设备布点与传输管网衔接设计
大数据·运维·人工智能·物联网·智慧城市·能源
AI即插即用15 小时前
即插即用系列 | WACV 2024 D-LKA:超越 Transformer?D-LKA Net 如何用可变形大核卷积刷新医学图像分割
图像处理·人工智能·深度学习·目标检测·计算机视觉·视觉检测·transformer
草莓熊Lotso15 小时前
《算法闯关指南:动态规划算法--斐波拉契数列模型》--04.解码方法
c++·人工智能·算法·动态规划
winfredzhang15 小时前
深入剖析 wxPython 配置文件编辑器
python·编辑器·wxpython·ini配置
狂放不羁霸15 小时前
电子科技大学2025年机器学习期末考试回忆
人工智能·机器学习
多恩Stone15 小时前
【3DV 进阶-9】Hunyuan3D2.1 中的 MoE
人工智能·pytorch·python·算法·aigc
爱打代码的小林15 小时前
网络爬虫基础
爬虫·python
B站计算机毕业设计之家15 小时前
大数据项目:基于python电商平台用户行为数据分析可视化系统 电商订单数据分析 Django框架 Echarts可视化 大数据技术(建议收藏)
大数据·python·机器学习·数据分析·django·电商·用户分析
weixin_4215850115 小时前
静态图(Static Graph) vs 动态执行(Eager Execution)
python
Chase_______15 小时前
AI 提升效率指南:如何高效书写提示词
人工智能·ai·prompt