batch norm 中 track_running_stats 的探索

复制代码
if self.track_running_stats:
    self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
    self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
    self.running_mean: Optional[Tensor]
    self.running_var: Optional[Tensor]
    self.register_buffer('num_batches_tracked',
                         torch.tensor(0, dtype=torch.long,
                                      **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
    self.num_batches_tracked: Optional[Tensor]
else:
    self.register_buffer("running_mean", None)
    self.register_buffer("running_var", None)
    self.register_buffer("num_batches_tracked", None)

基于条件 self.track_running_stats,self 对象执行了一系列的操作来注册缓冲区(buffer)和属性。

如果 self.track_running_stats 为 True,表示正在跟踪运行时统计信息,那么执行以下操作:

使用 self.register_buffer() 方法注册缓冲区 running_mean,其值为全零的张量,形状为 (num_features,)。num_features 是一个变量,表示特征的数量。**factory_kwargs 是一个包含其他关键字参数的字典,用于创建张量。类似地,使用 self.register_buffer() 方法注册缓冲区 running_var,其值为全一的张量,形状与 running_mean 相同。

register a buffer that should not to be considered a model parameter. are persistent and will be saved alongside parameters

定义属性 self.running_mean 和 self.running_var,它们的类型是 Optional[Tensor],即可选的张量类型。这些属性用于存储跟踪的运行时均值和方差。

使用 self.register_buffer() 方法注册缓冲区 num_batches_tracked,其值为一个长整型张量,初始值为0。这个缓冲区用于跟踪已处理的批次数量。

定义属性 self.num_batches_tracked,也是一个 Optional[Tensor] 类型,用于存储已处理的批次数量。

如果 self.track_running_stats 为 False,表示不跟踪运行时统计信息,那么执行以下操作:

使用 self.register_buffer() 方法分别注册缓冲区 running_mean、running_var 和 num_batches_tracked,它们的值都为 None,即空值。

这些操作的目的是根据条件设置合适的缓冲区和属性,以便在模型的训练和推理过程中进行运行时统计信息的跟踪和更新。如果跟踪统计信息,则使用缓冲区存储相关的均值、方差和已处理的批次数量;否则,这些属性被设置为 None。

torch._six.string_classes 是一个字符串类的元组,用于在 PyTorch 内部处理字符串类型的兼容性。它是一个内部使用的变量,通常不需要在用户的代码中直接使用。

在 PyTorch 中,torch._six.string_classes 用于处理字符串类型的兼容性问题,尤其是在不同的 Python 版本或不同的运行环境中。它定义了一组字符串类,以便在不同的环境中都能正确处理字符串的操作。

该元组包含了多个字符串类,例如:

str:Python 3.x 中的字符串类型。

unicode:Python 2.x 中的字符串类型。

bytes:Python 2.x 和 3.x 中的字节串类型。

通过使用 torch._six.string_classes,PyTorch 可以在不同的 Python 版本中兼容地处理字符串类型,以确保代码的可移植性和兼容性。

需要注意的是,由于 torch._six.string_classes 是一个内部使用的变量,它的具体内容和实现可能会在不同的 PyTorch 版本中有所变化。因此,建议在用户代码中使用标准的字符串类型,如 str,而不是直接依赖于 torch._six.string_classes。

相关推荐
2501_9411118217 小时前
使用Scikit-learn进行机器学习模型评估
jvm·数据库·python
木头左18 小时前
自适应门控循环单元GRU-O与标准LSTM在量化交易策略中的性能对比实验
深度学习·gru·lstm
哥布林学者18 小时前
吴恩达深度学习课程二: 改善深层神经网络 第三周:超参数调整,批量标准化和编程框架(三)多值预测与多分类
深度学习·ai
小呀小萝卜儿18 小时前
2025-11-14 学习记录--Python-使用sklearn+检测 .csv 文件的编码+读取 .csv 文件
python·学习
月下倩影时18 小时前
视觉学习篇——模型推理部署:从“炼丹”到“上桌”
人工智能·深度学习·学习
java1234_小锋18 小时前
[免费]基于python的Flask+Vue医疗疾病数据分析大屏可视化系统(机器学习随机森林算法+requests)【论文+源码+SQL脚本】
python·机器学习·数据分析·flask·疾病数据分析
高洁0118 小时前
国内外具身智能VLA模型深度解析(2)国外典型具身智能VLA架构
深度学习·算法·aigc·transformer·知识图谱
小殊小殊18 小时前
从零手撸Mamba!
人工智能·深度学习
MediaTea20 小时前
Python 第三方库:cv2(OpenCV 图像处理与计算机视觉库)
开发语言·图像处理·python·opencv·计算机视觉
江塘20 小时前
机器学习-决策树多种生成方法讲解及实战代码讲解(C++/Python实现)
c++·python·决策树·机器学习