目录
简介
ND4J主要是JVM的科学计算库,内置了很多计算方法,目的是以最低的RAM需求快速运行。主要特点是:
- 一个多功能的n维数组对象。
- 线性代数和信号处理函数。
-
多平台功能,包括GPU。
- 所有主要操作系统: win/linux/osx/android.
- 架构: x86, arm, ppc.
Nd4j的主要特点是具有多功能的n维阵列接口INDArray。为了提高性能,Nd4j使用堆外内存来存储数据。INDArray不同于标准Java数组。
基础用法
基础信息
创建数值为0的N维数组
java
INDArray xx = Nd4j.zeros(10);
System.out.println(xx);
//运行结果
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INDArray x = Nd4j.zeros(3, 4);
System.out.println(x);
//运行结果
[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]]
同理,也支持三维、四维等,可以通过以下方法进行维度信息的查看:
java
// 数组的轴数(维度)。
int dimensions = x.rank();
// 数组的维数。每个维度的大小。
long[] shape = x.shape();
// 元素的总数。
long length = x.length();
// 数组元素的类型。
DataType dt = x.dataType();
数组创建
要创建INDArray,可以使用ND4J类的静态工厂方法。
Nd4j.createFromArray
函数通过重载,支持double、float、int、long类型,同时对各种类型最高支持四维。
java
double arr_2d[][]={{1.0,2.0,3.0},{4.0,5.0,6.0},{7.0,8.0,9.0}};
INDArray x_2d = Nd4j.createFromArray(arr_2d);
double arr_1d[]={1.0,2.0,3.0};
INDArray x_1d = Nd4j.createFromArray(arr_1d);
可以使用函数zeros
和ones
创建用0和1初始化的数组
可以使用rand函数建用随机值初始化的数组
创建的INDArray的默认数据类型是float,有些重载允许您设置数据类型
java
INDArray x = Nd4j.zeros(5);
//运行结果
//[ 0, 0, 0, 0, 0], FLOAT
int [] shape = {5};
x = Nd4j.zeros(DataType.DOUBLE, 5);
//运行结果
//[ 0, 0, 0, 0, 0], DOUBLE
// 对于更高的维度,可以提供形状数组。二维随机矩阵示例:
int[] shape = {4, 5};
INDArray x = Nd4j.rand(shape);
//运行结果
[[ 0.5669, 0.0576, 0.8701, 0.9598, 0.6470],
[ 0.2711, 0.8427, 0.2819, 0.6617, 0.5109],
[ 0.0602, 0.2674, 0.6586, 0.9939, 0.4781],
[ 0.4099, 0.9503, 0.2227, 0.4738, 0.3759]]
使用**arange
函数**创建一个均匀空间值数组:
java
INDArray x = Nd4j.arange(5);
// [ 0, 1.0000, 2.0000, 3.0000, 4.0000]
INDArray x = Nd4j.arange(2, 7);
// [ 2.0000, 3.0000, 4.0000, 5.0000, 6.0000]
linspace
函数允许您指定生成的点数:
java
//开始数, 停止数, 个数.
INDArray x = Nd4j.linspace(1, 10, 5);
// [ 1.0000, 3.2500, 5.5000, 7.7500, 10.0000]
// 对函数进行多点评估。
import static org.nd4j.linalg.ops.transforms.Transforms.sin;
INDArray x = Nd4j.linspace(0.0, Math.PI, 100, DataType.DOUBLE);
INDArray y = sin(x);
打印数组
如上图的例子,INDArray支持Java的toString()
方法,可以直接通过System.out.println打印除结果
变更维度&堆叠
可以通过reshap函数进行维度的变更
java
int [] shape = {4,3};
//2维数组
x = Nd4j.arange(12).reshape(shape);
/*
[[ 0, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
[ 9.0000, 10.0000, 11.0000]]
*/
如果本身数据的数组不足以填满新的维度,则会报ND4JIllegalStateException异常
java
int [] shape = {4,3};
//2维数组
INDArray x = Nd4j.arange(11).reshape(shape);
//运行结果
//org.nd4j.linalg.exception.ND4JIllegalStateException: New shape length doesn't match original length: [12] vs [11]. Original shape: [11] New Shape: [4, 3]
java
INDArray x = Nd4j.rand(3,4);
x.shape();
// [3, 4]
INDArray x2 = x.ravel(); //转化成一列
x2.shape();
// [12]
INDArray x3 = x.reshape(6,2).shape();
x3.shape();
//[6, 2]
//注意x、x2和x3共享相同的数据。
x2.putScalar(5, -1.0);
System.out.println( x);
/*
[[ 0.0270, 0.3799, 0.5576, 0.3086],
[ 0.2266, -1.0000, 0.1107, 0.4895],
[ 0.8431, 0.6011, 0.2996, 0.7500]]
*/
System.out.println( x2);
// [ 0.0270, 0.3799, 0.5576, 0.3086, 0.2266, -1.0000, 0.1107, 0.4895, 0.8431, 0.6011, 0.2996, 0.7500]
System.out.println( x3);
/*
[[ 0.0270, 0.3799],
[ 0.5576, 0.3086],
[ 0.2266, -1.0000],
[ 0.1107, 0.4895],
[ 0.8431, 0.6011],
[ 0.2996, 0.7500]]
*/
可以使用vstack
和hstack
方法将数组堆叠在一起。
java
INDArray x = Nd4j.rand(2,2);
INDArray y = Nd4j.rand(2,2);
x
/*
[[ 0.1462, 0.5037],
[ 0.1418, 0.8645]]
*/
y;
/*
[[ 0.2305, 0.4798],
[ 0.9407, 0.9735]]
*/
Nd4j.vstack(x, y);
/*
[[ 0.1462, 0.5037],
[ 0.1418, 0.8645],
[ 0.2305, 0.4798],
[ 0.9407, 0.9735]]
*/
Nd4j.hstack(x, y);
/*
[[ 0.1462, 0.5037, 0.2305, 0.4798],
[ 0.1418, 0.8645, 0.9407, 0.9735]]
*/
加减乘除
使用INDArray方法对数组执行操作。
java
加法: arr.add(...), arr.addi(...)
减法: arr.sub(...), arr.subi(...)
乘法: arr.mul(...), arr.muli(...)
除法 : arr.div(...), arr.divi(...)
对数组进行加减乘除操作,有两种类型
【1】in-place :在原有数组基础上变更,如add、sub
【2】复制:创建新的数组,将结果放在新数组中,如addi、subi
java
//复制
//返回一个新数组,并将标量添加到arr的每个元素。
arr_new = arr.add(scalar);
//返回一个新数组,它是arr和其他arr元素级别的加法。
arr_new = arr.add(other_arr);
//In-place
arr_new = arr.addi(scalar);
arr_new = arr.addi(other_arr);
in-place运算符可以方便地将操作链接在一起。尽可能使用(in-place)运算符来提高性能,因为
复制运算符有新的数组创建开销。
注意,执行加减乘除操作时,必须确保基础数据类型相同。
java
int [] shape = {5};
INDArray x = Nd4j.zeros(shape, DataType.DOUBLE);
INDArray x2 = Nd4j.zeros(shape, DataType.INT);
INDArray x3 = x.add(x2);
// java.lang.IllegalArgumentException: Op.X 和 Op.Y must have the same data type, but got INT vs DOUBLE
// 将x2转换为DOUBLE可以解决以下问题:
INDArray x3 = x.add(x2.castTo(DataType.DOUBLE));
累加/最大/最小
INDArray有实现累加/最值操作的方法,如 sum
, min
, max
.
java
int [] shape = {2,3};
INDArray x = Nd4j.rand(shape);
x;
x.sum();
x.min();
x.max();
/*
[[ 0.8621, 0.9224, 0.8407],
[ 0.1504, 0.5489, 0.9584]]
4.2830
0.1504
0.9584
*/
提供维度参数以在指定维度上应用操作:
java
INDArray x = Nd4j.arange(12).reshape(3, 4);
/*
[[ 0, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]]
*/
//每列的总和。
x.sum(0);
//[ 12.0000, 15.0000, 18.0000, 21.0000]
//每行最小值
x.min(1);
//[ 0, 4.0000, 8.0000]
//每行的累计和,
x.cumsum(1);
/*
[[ 0, 1.0000, 3.0000, 6.0000],
[ 4.0000, 9.0000, 15.0000, 22.0000],
[ 8.0000, 17.0000, 27.0000, 38.0000]]
*/
转换操作
Nd4j提供熟悉的数学函数,如sin、cos和exp,这些称为转换操作。结果作为INDArray返回。
java
import static org.nd4j.linalg.ops.transforms.Transforms.exp;
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
INDArray x = Nd4j.arange(3);
// [ 0, 1.0000, 2.0000]
exp(x);
// [ 1.0000, 2.7183, 7.3891]
sqrt(x);
// [ 0, 1.0000, 1.4142]
矩陈乘法
INDArray也支持矩阵运算,如下:
java
INDArray x = Nd4j.arange(12).reshape(3, 4);
/*
[[ 0, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]]
*/
INDArray y = Nd4j.arange(12).reshape(4, 3);
/*
[[ 0, 1.0000, 2.0000],
[ 3.0000, 4.0000, 5.0000],
[ 6.0000, 7.0000, 8.0000],
[ 9.0000, 10.0000, 11.0000]]
*/
// 矩阵乘积.
x.mmul(y);
/*
[[ 42.0000, 48.0000, 54.0000],
[ 114.0000, 136.0000, 158.0000],
[ 186.0000, 224.0000, 262.0000]]
*/
//点积
INDArray x = Nd4j.arange(12);
INDArray y = Nd4j.arange(12);
dot(x, y);
//506.0000
索引/迭代
获取某个位置的数据
java
INDArray x = Nd4j.arange(12);
// [ 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000]
//单一元素访问。其他方法: getDouble, getInt, ...
float f = x.getFloat(3);
// 3.0
转化为Java数组
java
//转换为Java数组。
float [] fArr = x.toFloatVector();
// [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
获取指定区间的数据
java
INDArray x2 = x.get(NDArrayIndex.interval(2, 6));
// [ 2.0000, 3.0000, 4.0000, 5.0000]
在x的副本上:从开始到位置6(不包括),将每2个元素设置为-1.0
java
//在x的副本上:从开始到位置6(不包括),将每2个元素设置为-1.0
INDArray y = x.dup();
y.get(NDArrayIndex.interval(0, 2, 6)).assign(-1.0);
//[ -1.0000, 1.0000, -1.0000, 3.0000, -1.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000]
y的反向副本
java
//y的反向副本。
INDArray y2 = Nd4j.reverse(y.dup());
//[ 11.0000, 10.0000, 9.0000, 8.0000, 7.0000, 6.0000, 5.0000, -1.0000, 3.0000, -1.0000, 1.0000, -1.0000]
对于多维数组,应该使用INDArray.get(NDArrayIndex...)
。下面的示例演示如何遍历二维数组的行和列。注意,对于2D数组,我们可以使用getColumn
和getRow
便利方法。
java
// 在2d数组的行和列上迭代。
int rows = 4;
int cols = 5;
int[] shape = {rows, cols};
INDArray x = Nd4j.rand(shape);
/*
[[ 0.2228, 0.2871, 0.3880, 0.7167, 0.9951],
[ 0.7181, 0.8106, 0.9062, 0.9291, 0.5115],
[ 0.5483, 0.7515, 0.3623, 0.7797, 0.5887],
[ 0.6822, 0.7785, 0.4456, 0.4231, 0.9157]]
*/
for (int row=0; row<rows; row++) {
INDArray y = x.get(NDArrayIndex.point(row), NDArrayIndex.all());
}
/*
[ 0.2228, 0.2871, 0.3880, 0.7167, 0.9951]
[ 0.7181, 0.8106, 0.9062, 0.9291, 0.5115]
[ 0.5483, 0.7515, 0.3623, 0.7797, 0.5887]
[ 0.6822, 0.7785, 0.4456, 0.4231, 0.9157]
*/
for (int col=0; col<cols; col++) {
INDArray y = x.get(NDArrayIndex.all(), NDArrayIndex.point(col));
}
/*
[ 0.2228, 0.7181, 0.5483, 0.6822]
[ 0.2871, 0.8106, 0.7515, 0.7785]
[ 0.3880, 0.9062, 0.3623, 0.4456]
[ 0.7167, 0.9291, 0.7797, 0.4231]
[ 0.9951, 0.5115, 0.5887, 0.9157]
*/
深拷贝/引用传递/视图
引用传递
以下用法,仅仅只是指向x的指针,进行了引用传递,并未复制
java
INDArray x = Nd4j.rand(2,2);
//y和x指向同一个INDArray对象。
INDArray y = x;
视图
一些函数将返回数组的视图,并未进行数组的复制
java
INDArray x = Nd4j.rand(3,4);
INDArray x2 = x.ravel();
INDArray x3 = x.reshape(6,2);
// 修改 x, x2 和 x3
x2.putScalar(5, -1.0);
x
/*
[[ 0.8546, 0.1509, 0.0331, 0.1308],
[ 0.1753, -1.0000, 0.2277, 0.1998],
[ 0.2741, 0.8257, 0.6946, 0.6851]]
*/
x2
// [ 0.8546, 0.1509, 0.0331, 0.1308, 0.1753, -1.0000, 0.2277, 0.1998, 0.2741, 0.8257, 0.6946, 0.6851]
x3
/*
[[ 0.8546, 0.1509],
[ 0.0331, 0.1308],
[ 0.1753, -1.0000],
[ 0.2277, 0.1998],
[ 0.2741, 0.8257],
[ 0.6946, 0.6851]]
*/
深拷贝
要复制数组,请使用dup
方法。这将为您提供一个包含新数据的新数组。
java
INDArray x = Nd4j.rand(3,4);
INDArray x2 = x.ravel().dup();
//现在只改变x2。
x2.putScalar(5, -1.0);
x
/*
[[ 0.1604, 0.0322, 0.8910, 0.4604],
[ 0.7724, 0.1267, 0.1617, 0.7586],
[ 0.6117, 0.5385, 0.1251, 0.6886]]
*/
x2
// [ 0.1604, 0.0322, 0.8910, 0.4604, 0.7724, -1.0000, 0.1617, 0.7586, 0.6117, 0.5385, 0.1251, 0.6886]
其它
ND4J还有很多科学计算函数,具体可以查看文档