【神经网络】js版本的Pytorch,estorch重磅发布

文章目录

EsTorch

类似 PyTorch 的 JavaScript (ECMAScript) 张量/自动微分库,支持 CPU 训练。零依赖,纯 CommonJS 模块。

安装

bash 复制代码
npm install estorch

快速开始

javascript 复制代码
const torch = require('estorch');

// 创建张量
const a = torch.tensor([[1, 2], [3, 4]], { requiresGrad: true });
const b = torch.tensor([[5, 6], [7, 8]], { requiresGrad: true });

// 支持自动微分的运算
const c = a.matmul(b);
const loss = c.sum();

// 反向传播
loss.backward();
console.log(a.grad); // 梯度自动计算

API 参考

张量操作

函数 / 方法 说明
torch.tensor(data, options?) 从嵌套数组创建张量。选项:{ shape, requiresGrad, device }
torch.randn(shape, options?) 创建标准正态分布随机张量
torch.zeros(shape, options?) 创建全零张量
tensor.add(other) 逐元素加法
tensor.sub(other) 逐元素减法
tensor.mul(other) 逐元素或标量乘法
tensor.matmul(other) 矩阵乘法
tensor.tanh() Tanh 激活函数
tensor.relu() ReLU 激活函数
tensor.argmax(dim) 沿指定维度取最大值索引(支持二维张量的 dim=1
tensor.reshape(...shape) / tensor.view(...shape) 重塑张量形状
tensor.flatten(startDim?) startDim 开始展平
tensor.clone() 深拷贝
tensor.detach() 从计算图中分离
tensor.backward() 通过反向传播计算梯度
tensor.size(dim?) 获取形状或指定维度的大小
tensor.item() 获取标量值(仅限 0 维张量)
tensor.to(device) / tensor.cpu() 将张量移动到指定设备
torch.noGrad(fn) 在禁用梯度计算的情况下执行 fn

神经网络 (torch.nn)

类 / 函数 说明
nn.Module 所有模块的基类。提供 parameters()train()eval()stateDict()loadStateDict()to(device)
nn.Parameter(data, shape) 张量子类,自动注册为模块参数
nn.Linear(inFeatures, outFeatures) 全连接层
nn.Conv2d(inChannels, outChannels, kernelSize, options?) 二维卷积层。选项:{ stride, padding }
nn.AvgPool2d(kernelSize, options?) 二维平均池化层。选项:{ stride }
nn.CrossEntropyLoss() 交叉熵损失函数
函数式 API (torch.nn.functional)
函数 说明
functional.relu(x) ReLU 激活
functional.tanh(x) Tanh 激活
functional.maxPool2d(x, kernelSize, stride?) 最大池化
functional.avgPool2d(x, kernelSize, stride?) 平均池化
functional.crossEntropy(logits, labels) 交叉熵损失

优化器 (torch.optim)

说明
optim.Adam(params, options?) Adam 优化器。选项:{ lr, betas, eps }。方法:zeroGrad()step()stateDict()
optim.lrScheduler.StepLR(optimizer, options?) 阶梯式学习率调度器。选项:{ stepSize, gamma }

数据加载 (torch.data)

函数 / 类 说明
data.DataLoader(dataset, options?) 批量数据加载器,支持异步迭代器。选项:{ batchSize, shuffle, device }
data.randomSplit(dataset, lengths) 随机拆分数据集为子集
data.mnist.MNISTDataset(imagesPath, labelsPath, options?) MNIST 数据集类。选项:{ synthetic, tensorFactory, limit }
data.mnist.prepareMNIST(root, options?) 下载/准备 MNIST 数据。选项:{ download, synthetic }
data.mnist.downloadMNIST(root, options?) 下载 MNIST IDX 文件

工具函数

函数 说明
torch.device(type) 创建设备描述符。当前仅支持 "cpu"
torch.save(obj, path) 将对象保存为 JSON(仅 Node.js)
torch.load(path) 从文件加载 JSON 对象(仅 Node.js)

MNIST 训练示例

完整用户示例见 examples/mnist/main.js

bash 复制代码
# 克隆仓库后使用合成数据运行(无需下载)
git clone <your-repo-url>
cd estorch
node examples/mnist/main.js --synthetic

# 或使用真实 MNIST 数据
yarn mnist:download
node examples/mnist/main.js

示例参数

bash 复制代码
node examples/mnist/main.js --synthetic --epochs 3 --batch-size 64
node examples/mnist/main.js --data-dir ./datasets/mnist --lr 0.002 --train-limit 5000
参数 默认值 说明
--synthetic false 使用合成数据(无需下载)
--data-dir <path> ./datasets/mnist MNIST IDX 文件路径
--epochs <n> 1 训练轮数
--batch-size <n> 64 批大小
--lr <n> 0.001 学习率
--train-limit <n> 48000 训练样本数量限制
--val-limit <n> 12000 验证样本数量限制
--test-limit <n> 10000 测试样本数量限制

冒烟测试

bash 复制代码
yarn test:smoke

说明

  • 运行环境:Node.js >= 14.0.0
  • 模块系统 :CommonJS(require / module.exports
  • 后端:仅 CPU(WebGPU 支持为实验性功能,未包含在本包中)
  • 零依赖:无需外部包
  • torch.save()torch.load() 使用 fs 模块,仅在 Node.js 环境中可用

许可证

MIT

相关推荐
蔡俊锋1 小时前
AI自动化不是接工具就行,得补缺点搭轨道
人工智能·ai 效率
贫民窟的勇敢爷们1 小时前
Vue的渐进式特性,让前端开发更具灵活性
前端·javascript·vue.js
DXM05211 小时前
第11期:实战| ArcGIS Pro 遥感影像预处理
人工智能·arcgis·#arcpy·#arcgis 二次开发·#gis 自动化
Tutankaaa1 小时前
交通安全知识竞赛:文明出行,安全相伴
大数据·人工智能·安全
knight_9___1 小时前
大模型project面试2
人工智能
问心无愧05131 小时前
ctf show web入门81
前端·笔记
龙侠九重天1 小时前
大型语言模型结构化输出:用 JSON Schema 约束大模型输出
人工智能·语言模型·自然语言处理·大模型·json
China_Yanhy1 小时前
【云原生 AI 实战】EKS 搭建 GPU 超算集群:从零拉起节点到 PyTorchJob 分布式训练 (附 EFA 加速避坑指南)
人工智能·分布式·云原生
人工智能培训1 小时前
知识图谱与检索增强的实战结合
人工智能·深度学习·神经网络·机器学习·生成对抗网络