【神经网络】js版本的Pytorch,estorch重磅发布

文章目录

EsTorch

类似 PyTorch 的 JavaScript (ECMAScript) 张量/自动微分库,支持 CPU 训练。零依赖,纯 CommonJS 模块。

安装

bash 复制代码
npm install estorch

快速开始

javascript 复制代码
const torch = require('estorch');

// 创建张量
const a = torch.tensor([[1, 2], [3, 4]], { requiresGrad: true });
const b = torch.tensor([[5, 6], [7, 8]], { requiresGrad: true });

// 支持自动微分的运算
const c = a.matmul(b);
const loss = c.sum();

// 反向传播
loss.backward();
console.log(a.grad); // 梯度自动计算

API 参考

张量操作

函数 / 方法 说明
torch.tensor(data, options?) 从嵌套数组创建张量。选项:{ shape, requiresGrad, device }
torch.randn(shape, options?) 创建标准正态分布随机张量
torch.zeros(shape, options?) 创建全零张量
tensor.add(other) 逐元素加法
tensor.sub(other) 逐元素减法
tensor.mul(other) 逐元素或标量乘法
tensor.matmul(other) 矩阵乘法
tensor.tanh() Tanh 激活函数
tensor.relu() ReLU 激活函数
tensor.argmax(dim) 沿指定维度取最大值索引(支持二维张量的 dim=1
tensor.reshape(...shape) / tensor.view(...shape) 重塑张量形状
tensor.flatten(startDim?) startDim 开始展平
tensor.clone() 深拷贝
tensor.detach() 从计算图中分离
tensor.backward() 通过反向传播计算梯度
tensor.size(dim?) 获取形状或指定维度的大小
tensor.item() 获取标量值(仅限 0 维张量)
tensor.to(device) / tensor.cpu() 将张量移动到指定设备
torch.noGrad(fn) 在禁用梯度计算的情况下执行 fn

神经网络 (torch.nn)

类 / 函数 说明
nn.Module 所有模块的基类。提供 parameters()train()eval()stateDict()loadStateDict()to(device)
nn.Parameter(data, shape) 张量子类,自动注册为模块参数
nn.Linear(inFeatures, outFeatures) 全连接层
nn.Conv2d(inChannels, outChannels, kernelSize, options?) 二维卷积层。选项:{ stride, padding }
nn.AvgPool2d(kernelSize, options?) 二维平均池化层。选项:{ stride }
nn.CrossEntropyLoss() 交叉熵损失函数
函数式 API (torch.nn.functional)
函数 说明
functional.relu(x) ReLU 激活
functional.tanh(x) Tanh 激活
functional.maxPool2d(x, kernelSize, stride?) 最大池化
functional.avgPool2d(x, kernelSize, stride?) 平均池化
functional.crossEntropy(logits, labels) 交叉熵损失

优化器 (torch.optim)

说明
optim.Adam(params, options?) Adam 优化器。选项:{ lr, betas, eps }。方法:zeroGrad()step()stateDict()
optim.lrScheduler.StepLR(optimizer, options?) 阶梯式学习率调度器。选项:{ stepSize, gamma }

数据加载 (torch.data)

函数 / 类 说明
data.DataLoader(dataset, options?) 批量数据加载器,支持异步迭代器。选项:{ batchSize, shuffle, device }
data.randomSplit(dataset, lengths) 随机拆分数据集为子集
data.mnist.MNISTDataset(imagesPath, labelsPath, options?) MNIST 数据集类。选项:{ synthetic, tensorFactory, limit }
data.mnist.prepareMNIST(root, options?) 下载/准备 MNIST 数据。选项:{ download, synthetic }
data.mnist.downloadMNIST(root, options?) 下载 MNIST IDX 文件

工具函数

函数 说明
torch.device(type) 创建设备描述符。当前仅支持 "cpu"
torch.save(obj, path) 将对象保存为 JSON(仅 Node.js)
torch.load(path) 从文件加载 JSON 对象(仅 Node.js)

MNIST 训练示例

完整用户示例见 examples/mnist/main.js

bash 复制代码
# 克隆仓库后使用合成数据运行(无需下载)
git clone <your-repo-url>
cd estorch
node examples/mnist/main.js --synthetic

# 或使用真实 MNIST 数据
yarn mnist:download
node examples/mnist/main.js

示例参数

bash 复制代码
node examples/mnist/main.js --synthetic --epochs 3 --batch-size 64
node examples/mnist/main.js --data-dir ./datasets/mnist --lr 0.002 --train-limit 5000
参数 默认值 说明
--synthetic false 使用合成数据(无需下载)
--data-dir <path> ./datasets/mnist MNIST IDX 文件路径
--epochs <n> 1 训练轮数
--batch-size <n> 64 批大小
--lr <n> 0.001 学习率
--train-limit <n> 48000 训练样本数量限制
--val-limit <n> 12000 验证样本数量限制
--test-limit <n> 10000 测试样本数量限制

冒烟测试

bash 复制代码
yarn test:smoke

说明

  • 运行环境:Node.js >= 14.0.0
  • 模块系统 :CommonJS(require / module.exports
  • 后端:仅 CPU(WebGPU 支持为实验性功能,未包含在本包中)
  • 零依赖:无需外部包
  • torch.save()torch.load() 使用 fs 模块,仅在 Node.js 环境中可用

许可证

MIT

相关推荐
橙子家4 小时前
浏览器缓存之【身份与会话管理】:Cookies 和 Private state tokens
前端
To_OC5 小时前
LC 49 字母异位词分组:想到哈希表很简单,选对 key 才是精髓
javascript·算法·leetcode
最新资讯动态5 小时前
HDC 2026 | 对话鲸鸿动能:存量时代,品牌如何夺回营销“主动权”?
前端
最新资讯动态5 小时前
游戏出海,从产品走向体系
前端
最新资讯动态5 小时前
20人团队跑出百万DAU、大厂也来抢量:谁在鸿蒙生态跑出加速度
前端
最新资讯动态6 小时前
千万开发者背后,鸿蒙商业化的B面
前端
冬奇Lab7 小时前
每日一个开源项目(第140篇):AgentScope 2.0 - 阿里开源的生产级 Agent 框架
人工智能·开源·agent
冬奇Lab7 小时前
Skill 系列(04):Skill 指标体系——L1/L2/L3 三层监控,让质量下降有据可查
人工智能·开源·llm
爱勇宝7 小时前
AI 时代:智商决定起点,情商决定走多远
前端·ai编程
kyriewen8 小时前
用了半年 Claude Code 后,我尝试关掉它写了一周代码——结果比想象中严重
前端·javascript·ai编程