文章目录
- EsTorch
-
- 安装
- 快速开始
- [API 参考](#API 参考)
-
- 张量操作
- 神经网络 (`torch.nn`)
-
- [函数式 API (`torch.nn.functional`)](#函数式 API (
torch.nn.functional))
- [函数式 API (`torch.nn.functional`)](#函数式 API (
- 优化器 (`torch.optim`)
- 数据加载 (`torch.data`)
- 工具函数
- [MNIST 训练示例](#MNIST 训练示例)
- 冒烟测试
- 说明
- 许可证
EsTorch
类似 PyTorch 的 JavaScript (ECMAScript) 张量/自动微分库,支持 CPU 训练。零依赖,纯 CommonJS 模块。

安装
bash
npm install estorch
快速开始
javascript
const torch = require('estorch');
// 创建张量
const a = torch.tensor([[1, 2], [3, 4]], { requiresGrad: true });
const b = torch.tensor([[5, 6], [7, 8]], { requiresGrad: true });
// 支持自动微分的运算
const c = a.matmul(b);
const loss = c.sum();
// 反向传播
loss.backward();
console.log(a.grad); // 梯度自动计算
API 参考
张量操作
| 函数 / 方法 | 说明 |
|---|---|
torch.tensor(data, options?) |
从嵌套数组创建张量。选项:{ shape, requiresGrad, device } |
torch.randn(shape, options?) |
创建标准正态分布随机张量 |
torch.zeros(shape, options?) |
创建全零张量 |
tensor.add(other) |
逐元素加法 |
tensor.sub(other) |
逐元素减法 |
tensor.mul(other) |
逐元素或标量乘法 |
tensor.matmul(other) |
矩阵乘法 |
tensor.tanh() |
Tanh 激活函数 |
tensor.relu() |
ReLU 激活函数 |
tensor.argmax(dim) |
沿指定维度取最大值索引(支持二维张量的 dim=1) |
tensor.reshape(...shape) / tensor.view(...shape) |
重塑张量形状 |
tensor.flatten(startDim?) |
从 startDim 开始展平 |
tensor.clone() |
深拷贝 |
tensor.detach() |
从计算图中分离 |
tensor.backward() |
通过反向传播计算梯度 |
tensor.size(dim?) |
获取形状或指定维度的大小 |
tensor.item() |
获取标量值(仅限 0 维张量) |
tensor.to(device) / tensor.cpu() |
将张量移动到指定设备 |
torch.noGrad(fn) |
在禁用梯度计算的情况下执行 fn |
神经网络 (torch.nn)
| 类 / 函数 | 说明 |
|---|---|
nn.Module |
所有模块的基类。提供 parameters()、train()、eval()、stateDict()、loadStateDict()、to(device) |
nn.Parameter(data, shape) |
张量子类,自动注册为模块参数 |
nn.Linear(inFeatures, outFeatures) |
全连接层 |
nn.Conv2d(inChannels, outChannels, kernelSize, options?) |
二维卷积层。选项:{ stride, padding } |
nn.AvgPool2d(kernelSize, options?) |
二维平均池化层。选项:{ stride } |
nn.CrossEntropyLoss() |
交叉熵损失函数 |
函数式 API (torch.nn.functional)
| 函数 | 说明 |
|---|---|
functional.relu(x) |
ReLU 激活 |
functional.tanh(x) |
Tanh 激活 |
functional.maxPool2d(x, kernelSize, stride?) |
最大池化 |
functional.avgPool2d(x, kernelSize, stride?) |
平均池化 |
functional.crossEntropy(logits, labels) |
交叉熵损失 |
优化器 (torch.optim)
| 类 | 说明 |
|---|---|
optim.Adam(params, options?) |
Adam 优化器。选项:{ lr, betas, eps }。方法:zeroGrad()、step()、stateDict() |
optim.lrScheduler.StepLR(optimizer, options?) |
阶梯式学习率调度器。选项:{ stepSize, gamma } |
数据加载 (torch.data)
| 函数 / 类 | 说明 |
|---|---|
data.DataLoader(dataset, options?) |
批量数据加载器,支持异步迭代器。选项:{ batchSize, shuffle, device } |
data.randomSplit(dataset, lengths) |
随机拆分数据集为子集 |
data.mnist.MNISTDataset(imagesPath, labelsPath, options?) |
MNIST 数据集类。选项:{ synthetic, tensorFactory, limit } |
data.mnist.prepareMNIST(root, options?) |
下载/准备 MNIST 数据。选项:{ download, synthetic } |
data.mnist.downloadMNIST(root, options?) |
下载 MNIST IDX 文件 |
工具函数
| 函数 | 说明 |
|---|---|
torch.device(type) |
创建设备描述符。当前仅支持 "cpu" |
torch.save(obj, path) |
将对象保存为 JSON(仅 Node.js) |
torch.load(path) |
从文件加载 JSON 对象(仅 Node.js) |
MNIST 训练示例
完整用户示例见 examples/mnist/main.js。
bash
# 克隆仓库后使用合成数据运行(无需下载)
git clone <your-repo-url>
cd estorch
node examples/mnist/main.js --synthetic
# 或使用真实 MNIST 数据
yarn mnist:download
node examples/mnist/main.js
示例参数
bash
node examples/mnist/main.js --synthetic --epochs 3 --batch-size 64
node examples/mnist/main.js --data-dir ./datasets/mnist --lr 0.002 --train-limit 5000
| 参数 | 默认值 | 说明 |
|---|---|---|
--synthetic |
false | 使用合成数据(无需下载) |
--data-dir <path> |
./datasets/mnist |
MNIST IDX 文件路径 |
--epochs <n> |
1 | 训练轮数 |
--batch-size <n> |
64 | 批大小 |
--lr <n> |
0.001 | 学习率 |
--train-limit <n> |
48000 | 训练样本数量限制 |
--val-limit <n> |
12000 | 验证样本数量限制 |
--test-limit <n> |
10000 | 测试样本数量限制 |
冒烟测试
bash
yarn test:smoke
说明
- 运行环境:Node.js >= 14.0.0
- 模块系统 :CommonJS(
require/module.exports) - 后端:仅 CPU(WebGPU 支持为实验性功能,未包含在本包中)
- 零依赖:无需外部包
torch.save()和torch.load()使用fs模块,仅在 Node.js 环境中可用
许可证
MIT