基本运算中,包括add, sub, mul, div, neg等函数,以及这些函数的带下划线的版本add_, sub_, mul_, div_, neg_, 其中带下划线的版本为修改原数据。
代码
python
import torch
import numpy as np
# 不修改原数据的计算
def test01():
data = torch.randint(0, 10, [2, 3]) # 开始值,结束值,形状
print(data)
# 计算完成之后,会返回一个新的张量
data = data.add(10)
print(data)
# data.sub()
# data.mul()
# data.div()
# data.neg() 取相反数
# 修改原数据的计算 (inplace方式的计算)
def test02():
data = torch.randint(0, 10, [2, 3]) # 开始值,结束值,形状
print(data)
# 带下划线的版本的函数直接修改原数据,不需要用新的变量保存
data.add_(10) # inplace=True
print(data)
# data.sub_()
# data.mul_()
# data.div_()
# data.neg_() 取相反数
if __name__ == "__main__":
test02()