Python 神经网络项目常用语法

  • [一、 工具](#一、 工具)
    • [1. 导入模块和包](#1. 导入模块和包)
    • [2. 修改系统路径 (sys.path.append)](#2. 修改系统路径 (sys.path.append))
    • [3. 命令行参数解析 (argparse 模块)](#3. 命令行参数解析 (argparse 模块))
    • [4. 字符串格式化](#4. 字符串格式化)
    • [5. 获取设备](#5. 获取设备)
    • [6. 加载数据](#6. 加载数据)
    • 7.设置推理结果的子路径
    • [8. main() 脚本入口点](#8. main() 脚本入口点)
  • 二、类相关
    • [1. 类的定义及初始化](#1. 类的定义及初始化)
    • [2. 类的实例化及函数调用](#2. 类的实例化及函数调用)
  • 三、神经网络常用类
    • [1. 工具类](#1. 工具类)
      • [1.1 正弦位置编码类](#1.1 正弦位置编码类)
      • [1.2 上采样/下采样](#1.2 上采样/下采样)
      • [1.3 标准化](#1.3 标准化)
      • [1.4 归一化层](#1.4 归一化层)
      • [1.5 提取函数 extract](#1.5 提取函数 extract)
    • [2. 构建网络块类](#2. 构建网络块类)
      • [2.1 block()](#2.1 block())
      • [2.2 残差连接](#2.2 残差连接)
      • [2.3 ResnetBlock](#2.3 ResnetBlock)
      • [2.3 Attention](#2.3 Attention)

一、 工具

1. 导入模块和包

python 复制代码
import os
import argparse
import sys as s
from accelerate import Accelerator
  • import:用于导入模块和包,可以选择导入单个模块、多个模块。
  • from ... import ...从特定模块中导入具体的类、函数或变量。
  • impoert ... as ...:可以为导入的模块指定别名,使代码简洁。

2. 修改系统路径 (sys.path.append)

python 复制代码
# 返回当前脚本所在目录的父目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 

# 返回当前脚本所在目录的上上级目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

这两行代码通过 os.path.dirname(__file__) 获取当前脚本所在的目录,然后使用 .. 依次返回上一级目录。

  • sys.path.append()修改模块搜索路径 ,可以动态添加额外的搜索路径,以便访问其他目录中的模块。
  • os.path.join(a, b, ...)将多个路径部分连接成一个完整的路径 ,并确保使用正确的路径分隔符 (在 UNIX 系统上是 /,在 Windows 上是 \)。
    • 例如,os.path.join("/home/user", "project") 将返回 /home/user/project
  • os.path.dirname(__file__).... 是用来构建路径的。
    • __file__ 是一个内置变量 ,它指向当前脚本文件的路径 (字符串格式)。例如,如果脚本文件的路径是 /home/user/project/scripts/train.py,则 __file__ 就是这个文件的完整路径。
    • os.path.dirname(path) 函数返回给定路径 path 的父目录(即去掉路径中的文件名部分)。例如,如果 __file__/home/user/project/scripts/train.py,那么 os.path.dirname(__file__) 将返回 /home/user/project/scripts,即文件所在的目录。
    • .. 在路径中代表上一级目录 ,是一个相对路径。如果当前路径是 /home/user/project/scripts,则 os.path.join("/home/user/project/scripts", '..') 会返回 /home/user/project,即当前脚本所在目录的父目录。

3. 命令行参数解析 (argparse 模块)

python 复制代码
import argparse

# 创建 ArgumentParser 对象,description 参数会显示在帮助信息中
parser = argparse.ArgumentParser(description='Train EBM model')


# 添加命令行参数,指定名称、默认值、类型、帮助信息
parser.add_argument('--model_type', default="states", type=str,
                    help='choices: states | thetas')
parser.add_argument('--dataset', default='jellyfish', type=str, help='dataset to evaluate')
parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
parser.add_argument('--epochs', default=10, type=int, help='Number of epochs to train the model')

# 解析命令行参数并将其存储在 FLAGS 对象中
FLAGS = parser.parse_args()

# 使用条件判断 if-elseif-else 语句,根据 model_type 选择相应的操作
if FLAGS.model_type == "states":
    ...
elif FLAGS.model_type == "thetas":
    ...

# 输出解析的参数值
print("Dataset:", FLAGS.dataset)
print("Batch size:", FLAGS.batch_size)
print("Epochs:", FLAGS.epochs)

在大多数深度学习训练脚本中,使用 argparse 模块来处理命令行参数

  • parser = argparse.ArgumentParser():创建一个命令行参数解析器 parserparser 是通过 argparse.ArgumentParser() 创建的实例,它负责定义和解析命令行输入。
  • parser.add_argument()添加命令行参数 ,指定名称、类型、默认值及帮助信息。
    • default参数的默认值,如果命令行未提供该参数,则使用默认值。
    • type:定义参数的数据类型
    • help:提供该参数的帮助描述
  • parser.parse_args()从命令行中解析传入的参数 ,并将它们存储在 args 对象中。args 是一个命名空间对象,可以通过点语法访问各个参数。

运行命令行的不同示例:

  • 不传递任何参数,使用默认值

    bash 复制代码
     # 命令行
     python script.py

    输出:

    python 复制代码
    Dataset: jellyfish
    Batch size: 4
    Epochs: 10
  • 传递自定义的命令行参数

    bash 复制代码
    # 命令行
    python script.py --dataset "fish" --batch_size 8 --epochs 20

    输出:

    python 复制代码
    Dataset: fish
    Batch size: 8
    Epochs: 20
  • 查看帮助信息

    bash 复制代码
    python script.py --help

    输出:

    python 复制代码
    usage: script.py [-h] [--dataset DATASET] [--batch_size BATCH_SIZE] [--epochs EPOCHS]
    Train EBM model
    optional arguments:
     -h, --help                   show this help message and exit
     --dataset DATASET            Dataset to evaluate
     --batch_size BATCH_SIZE      Batch size for training
     --epochs EPOCHS              Number of epochs to train the model

4. 字符串格式化

python 复制代码
print("Saved at: ", results_path)
print("DATA_PATH: ", DATA_PATH)
print("number of parameters in model: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
  • print():用于输出信息。
  • 字符串连接:print 语句中使用逗号分隔多个变量 ,可以连接字符串和变量

5. 获取设备

python 复制代码
def get_device():
    return torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

args.device = get_device()

在深度学习任务中,通常需要指定训练的设备,如果你的机器有支持的 GPU,cuda 可以加速模型训练。如果没有 GPU,则会使用 CPU。

get_device() 这个函数会检查当前是否有可用的 GPU

  • torch.device("cuda:4") 中的 cuda:4 表示选择第 5 个 GPU。如果你的机器没有 5 个以上的 GPU,可能会报错。你可以使用 cudacuda:0(通常默认使用第一个 GPU)。
  • torch.cuda.is_available() 用来检查当前是否有可用的 GPU,如果没有返回 False,则会回退到 CPU。

args.device = get_device() 调用 get_device() 函数,将返回的设备对象(cuda:4cpu)赋值给 args.deviceargs.device 之后可以用来指定模型的设备,确保模型训练或推理时使用正确的硬件资源。

上述代码可修改为:

python 复制代码
def get_device():
    if torch.cuda.is_available():
        # 如果有多个 GPU,选择第一个 GPU,避免 'cuda:4' 报错
        return torch.device("cuda:0")  
    else:
        return torch.device("cpu")

args.device = get_device()

6. 加载数据

python 复制代码
def cycle(dl):
    while True:
        for data in dl:
            yield data

无限循环地返回数据加载器中的数据。

7.设置推理结果的子路径

python 复制代码
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
args.inference_result_subpath = os.path.join(
    args.inference_result_path,
    current_time +  "_coeff_ratio_w_{}_J{}_".format(args.coeff_ratio_w, args.coeff_ratio_J)
)

这段代码构建推理结果保存的子目录路径 inference_result_subpath,这个路径将包含当前时间戳 以及与 coeff_ratio_wcoeff_ratio_J 参数相关的信息。

  • 使用 datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 获取当前的时间戳,格式为 年-月-日_时-分-秒
  • 使用 args.coeff_ratio_wargs.coeff_ratio_J 格式化字符串,这两个参数的值会被嵌入路径中 。这些参数可能是用户在命令行中设置的值,影响模型训练或推理过程中的某些超参数
  • 最终路径会类似于:inference_result_path/2024-11-17_12-30-45_coeff_ratio_w_0.3_J0.5_

8. main() 脚本入口点

python 复制代码
# inference.py
def main(args):
    diffusion = load_model(args)
    dataloader = load_data(args)
    inference(dataloader, diffusion, args)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='xxx model')
	...
	args = parser.parse_args()
	main(args)  # # 调用 main 函数

main 函数通常是脚本的核心逻辑部分,处理程序的主要任务。在 main() 函数中,可以使用 args 对象中的参数值来决定程序的行为,比如设置模型的超参数、选择训练模式等。

这段代码展示了 main 函数的核心逻辑,主要包括以下几个步骤:

  1. 加载模型diffusion = load_model(args)):这行代码通过调用 load_model 函数加载模型及相关组件。args 参数传递了外部配置,可能包含模型路径、超参数等信息。load_model 函数返回 diffusion 扩散模型对象。

  2. 加载数据dataloader = load_data(args)):这行代码调用 load_data 函数加载数据 。args 中包含有关数据的路径或其他配置信息。dataloader 通常是一个生成数据批次(batch)的迭代器,可以用于训练或推理过程。

  3. 推理inference(dataloader, diffusion, args))在这行代码中,inference 函数接受输入数据的迭代器、扩散模型和配置信息 参数,inference 函数通常用于执行模型的推理过程,生成预测或结果

  • if __name__ == '__main__' 这一行的作用是确保当这个脚本作为主程序执行时,main() 函数会被调用。
  • main(args) 这行代码调用 main() 函数 ,并将 args 对象作为参数传递给它

因此,如果这个脚本是作为独立的 Python 文件运行的,它会执行 main 函数

当你运行命令 python inference.py 时:

  • Python 解释器会执行整个 inference.py 文件
  • 脚本中的 if __name__ == '__main__': 块会被触发。
  • 然后会调用 main(args) 函数,并传入通过 argparse 解析的命令行参数 args。

这意味着,main 函数将会被执行,在 main 中使用命令行参数 args 来配置模型、加载数据、进行推理等

二、类相关

要想使用类,先使用 class 关键字去定义类 ,而如何使用类在类方法中定义 ,然后再用实例化对象调用类方法

  • 有个特殊参数 每个类都有,它叫 self,它代表类的当前实例 ,使得方法可以访问和修改实例的属性

  • 有个特殊的类方法 每个类都会定义,它叫 __init__初始化方法 或构建方法,它会在这个类创建实例化对象自动被调用 ,并且传入实例化时的参数 。通常在该方法中对实例属性进行初始化

python 复制代码
# 定义一个类
class Dog:
    def __init__(self, name, age):  # 初始化方法
        self.name = name  # 初始化属性
        self.age = age

    def speak(self):  # 类方法
        print(f"{self.name} says woof!")

# 实例化一个对象
dog1 = Dog("Buddy", 3)  # 创建 Dog 类的实例,__init__ 方法自动执行

# 调用类方法
dog1.speak()  # 输出: Buddy says woof!
  • Dog 是类,定义了 __init__ 方法来初始化 name 和 age 属性。
  • speak() 是一个类方法,它输出 name 属性的内容。
  • 通过 dog1 = Dog("Buddy", 3) 创建了一个 Dog 类的实例,__init__ 方法自动被调用,初始化了 dog1 的属性。
  • dog1.speak() 实例化对象 dog1 调用类方法,使用类。

1. 类的定义及初始化

__init__ 是 Python 中的类初始化方法 ,也叫构造方法 。它用于在类的对象被创建时初始化对象的状态(即设置对象的属性)

__init__ 方法会在类的实例化时自动调用,并且在对象创建后执行

python 复制代码
class ClassName:
    def __init__(self, parameters):
        # 初始化属性或执行其他必要的操作
        self.attribute = value
        # 其他代码
  • __init__(self, parameters)__init__ 方法接受至少一个参数,通常是 self ,它表示类的实例对象 。parameters 是该方法接受的其他参数 ,用于在初始化时传递值
  • self.attribute:self 表示当前对象实例,可以通过它访问类的属性和方法
  • value 是初始化时为属性赋的值,可以是常量、变量或通过其他逻辑生成的值。

示例 1:基础初始化方法

python 复制代码
class Person:
    def __init__(self, name, age):
        # 初始化时将名字和年龄赋值给对象的属性
        self.name = name
        self.age = age

    def introduce(self):
        print(f"Hello, my name is {self.name} and I am {self.age} years old.")

在这个例子中,Person 类的 __init__ 方法接受两个其他参数:name 和 age,并将它们赋值给实例对象的属性 self.nameself.age

python 复制代码
# 创建对象时,初始化属性
person1 = Person("Alice", 30)
person2 = Person("Bob", 25)

# 调用方法
person1.introduce()  # 输出: Hello, my name is Alice and I am 30 years old.
person2.introduce()  # 输出: Hello, my name is Bob and I am 25 years old.

示例 2:使用默认参数

python 复制代码
class Car:
    def __init__(self, make, model, year=2020):
        # 初始化时,year 如果未传递将默认为 2020
        self.make = make
        self.model = model
        self.year = year

    def display_info(self):
        print(f"{self.year} {self.make} {self.model}")

在这个例子中,year 参数具有默认值 2020。如果在创建 Car 实例时未传递 year 参数,它会自动使用默认值 2020。

python 复制代码
# 创建时传递所有参数
car1 = Car("Toyota", "Camry", 2021)

# 创建时只传递 make 和 model,year 会使用默认值 2020
car2 = Car("Honda", "Civic")

# 输出汽车信息
car1.display_info()  # 输出: 2021 Toyota Camry
car2.display_info()  # 输出: 2020 Honda Civic

2. 类的实例化及函数调用

python 复制代码
# train_script.py

# Unet3D_with_Conv3D、GaussianDiffusion、Trainer 是从模块中导入的类
from model.video_diffusion_pytorch.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from diffusion.diffusion_2d_jellyfish import GaussianDiffusion, Trainer

if __name__ == "__main__":
	# 解析命令行参数并将其存储在 FLAGS 对象中
	FLAGS = parser.parse_args()

	# 创建 Unet3D_with_Conv3D 模型实例
	model = Unet3D_with_Conv3D(
		dim = 64,  # 设置模型的基础维度大小
		out_dim = 1 if FLAGS.only_vis_pressure else 3,  # 根据命令行参数 only_vis_pressure 决定输出维度
		dim_mults = (1, 2, 4),  # 传递一个元组作为参数,用于指定每个网络层维度的倍数
		channels=5 if FLAGS.only_vis_pressure else 7  # 根据命令行参数 only_vis_pressure 决定通道数
		)
        
	# 创建 GaussianDiffusion 实例
	diffusion = GaussianDiffusion(
		model,
        image_size = 64,
        frames=FLAGS.frames,
        cond_steps=FLAGS.cond_steps,
        timesteps = 1000,           # 设置扩散步骤数
        sampling_timesteps = 250,   # 采样步骤数
        loss_type = 'l2',           # 设置损失函数类型:L1 or L2
        objective = "pred_noise",
        device =device              # 模型运行的设备(CPU/GPU)
    )
    
	# 创建 Trainer 类的实例,该类用于管理模型的训练
	trainer = Trainer(
        diffusion,
        FLAGS.dataset,
        FLAGS.dataset_path,
        FLAGS.frames,
        FLAGS.traj_len,
        FLAGS.ts,
        FLAGS.log_path,
        train_batch_size = FLAGS.batch_size,  # 训练的批次大小
        train_lr = 1e-3,                  # 学习率
        train_num_steps = 400000,         # 总训练步数
        gradient_accumulate_every = 1,    # 指定进行梯度累积的次数
        ema_decay = 0.995,                # 用于模型参数的指数移动平均值的衰减因子
        save_and_sample_every = 4000,     # 每 4000 步保存模型和进行采样
        results_path = results_path,
        amp = False,                      # 是否使用混合精度训练
        calculate_fid = False,            # 训练过程中是否计算 fid
        is_testdata = FLAGS.is_testdata,
        only_vis_pressure = FLAGS.only_vis_pressure,
        model_type = FLAGS.model_type
    	)
    
	trainer.train()  # 调用 Trainer 类的 train 方法,启动模型的训练过程

这段代码展示了如何定义和使用深度学习模型的训练流程 ,包括模型定义、模型实例化、训练参数设置 ,以及如何通过面向对象编程实现模块化if __name__ == "__main__" 使得这段代码在直接运行脚本时会执行训练逻辑,而在导入时不会执行,从而提高了代码的复用性和模块化水平。

  • if __name__ == "__main__":它是模块和脚本的运行入口 。该语句下的代码仅在该脚本作为主程序运行时才会被执行

    • __name__ 变量:每个 Python 模块都有一个内置属性 __name__,其值决定了模块是被导入还是直接运行
    • __main__:当一个 Python 文件被直接运行 时,__name__ 的值会被设置为 __main__
    • 导入时的行为:如果该模块被其他脚本导入__name__ 的值是该模块的文件名 (不带路径和 .py 扩展名)。
  • model = ClassName(arguments)创建类的实例 ,通过类构造函数 ClassName 初始化对象 model

  • out_dim = 1 if FLAGS.only_vis_pressure else 3:使用条件表达式(类似三元运算符)来设置输出维度,如果 FLAGS.only_vis_pressure 为真,out_dim 为 1,否则为 3。

示例:

python 复制代码
# main_script.py
if __name__ == "__main__":
    print("This will only run when main_script.py is executed directly.")

运行结果:

  • 如果运行 python main_script.py,将输出 This will only run when main_script.py is executed directly.
  • 如果 main_script.py 被其他脚本导入,如 import main_script,这行代码不会被执行。

三、神经网络常用类

1. 工具类

1.1 正弦位置编码类

常见于自然语言处理(如 Transformer)和其他序列建模 任务中。正弦位置编码是一种通过正弦和余弦函数表示位置 的方式,使得模型能够感知输入数据中元素的顺序

python 复制代码
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
  • dim 是一个超参数,表示生成的位置编码的维度(即输出编码的总维度)。通常,这个维度应当是偶数,因为在后续计算中使用了正弦和余弦函数。
  • forward 方法 定义了正弦位置编码的计算过程
  • half_dim = self.dim // 2:将 dim 除以 2,计算出编码的一半维度,这将用于后续的正弦和余弦计算。
  • emb = math.log(10000) / (half_dim - 1):计算一个常数,它与位置编码的尺度有关。这里的 10000 是一个经验值,常用于位置编码的计算。
  • torch.exp(torch.arange(half_dim, device=device) * -emb):生成一个递减的指数序列
    torch.arange(half_dim) 生成一个从 0 到 half_dim-1 的整数序列 。然后通过乘以 -emb,再取指数,得到一组缩放因子,这些因子将用于后续的正弦和余弦函数。
  • x[:, None]:通过切片操作将 x 的维度从 [batch_size] 扩展到 [batch_size, 1],将其广播到与 emb 相乘。这意味着每个位置 x 的值将与 emb 中的每个尺度因子相乘,生成一个位置的尺度序列
    emb[None, :]:emb 被扩展到 [1, half_dim],然后与 x[:, None] 进行广播相乘,生成一个包含每个位置的编码因子
  • emb.sin()emb.cos():对每个位置的编码因子 ,分别计算正弦和余弦值。
    torch.cat((emb.sin(), emb.cos()), dim=-1):将计算得到的正弦和余弦值沿最后一个维度(dim=-1)连接 起来,形成最终的编码。这将使得每个位置的编码有两倍的维度(half_dim 为每种类型,正弦和余弦各一半)。
  • return emb:返回生成的位置编码张量 emb。

位置编码为输入序列的每个位置生成一个向量 ,使得模型可以感知不同元素的相对位置。位置编码通过正弦和余弦函数的组合来实现,这种设计能够让模型在不同的尺度上感知位置信息,且不依赖于具体的训练数据。

该方法的关键优势是它为每个位置生成了一个独特的编码 ,不同的位置信息通过正弦和余弦函数映射到不同的维度上 ,能够捕捉到位置之间的相对关系

RandomOrLearnedSinusoidalPosEmb:可以选择使用随机的或学习的正弦位置编码。

python 复制代码
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim, is_random=False):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad=not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi

1.2 上采样/下采样

  • Upsample:上采样模块,使用最近邻插值方法和卷积层进行特征图的上采样。
python 复制代码
def Upsample(dim, dim_out = None):
    return nn.Sequential(
		# 上采样操作,将输入特征图的尺寸扩大 2 倍,使用最近邻插值进行上采样
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        # 卷积层,输入通道数为 dim,输出通道数为 dim_out(如果 dim_out 未定义,则默认为 dim),卷积核大小为 3x3
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )
  • Downsample :下采样模块,使用 Rearrange 层对特征图进行降维。
python 复制代码
def Downsample(dim, dim_out = None):
    return nn.Sequential(
		# 输入张量的形状是 (batch_size, c, h, w),通过 p1=2 和 p2=2,它将 h 和 w 都按 2 的倍数重排列,
		# 因此,每个空间位置的特征将从 c 维度与 h 和 w 的子块合并,形成 dim * 4 个特征通道。这就相当于一个下采样操作,将空间尺寸减小一半。
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        # 卷积层,用于将通道数从 dim * 4 转换为 dim_out,dim_out 默认为 dim。积核大小为 1x1,通常用于改变通道数而不改变空间尺寸。
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

这两个函数 UpsampleDownsample 定义了数据的上采样和下采样操作。它们对输入数据的空间分辨率通道数进行处理,其工作方式如下:

  1. 上采样 Upsample(dim, dim_out):将空间分辨率放大 2 倍(从 (H, W)(2H, 2W)),并通过卷积调整通道数。
  • nn.Upsample(scale_factor=2, mode='nearest'):

    • 使用最近邻插值法,将输入的空间尺寸(高度和宽度)扩大为原来的 2 倍。
    • 例如,输入张量大小为 (B, C, H, W),则经过这一步后,大小变为 (B, C, 2H, 2W)
  • nn.Conv2d(dim, dim_out, 3, padding=1):

    • 卷积层,使用 3 × 3 3 \times 3 3×3 的卷积核对放大的数据进行处理,调整通道数为 dim_out
    • padding=1 确保空间分辨率保持不变 (仍为 (2H, 2W))。
    • 如果 dim_out 未指定,默认设置为 dim,即输出的通道数与输入一致。

最终,Upsample 的效果是

  • 输入维度 : (B, C, H, W)
  • 输出维度 : (B, dim_out, 2H, 2W)
  1. 下采样Downsample(dim, dim_out):将空间分辨率缩小 2 倍(从 (H, W)(H/2, W/2)),并通过通道重排和卷积调整通道数。
  • Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2):

    • 通过重排张量,将空间维度分块到通道维度。
    • 将每 2 × 2 2 \times 2 2×2 的块合并成新的通道
    • 输入的大小为 (B, C, H, W),输出的大小变为 (B, 4C, H/2, W/2)

    重排细节

    • 空间维度 (H, W) 被分成 H/2W/2 ,每个分块是 2 × 2 2 \times 2 2×2。
    • C 通道被扩展为 4C ,因为 2 × 2 = 4 2 \times 2 = 4 2×2=4。
  • nn.Conv2d(dim * 4, dim_out, 1):

    • 使用 1 × 1 1 \times 1 1×1 的卷积核对重新排列的通道进行压缩,调整通道数为 dim_out
    • 如果 dim_out 未指定,默认值是 dim

最终,Downsample 的效果是

  • 输入维度 : (B, C, H, W)
  • 输出维度 : (B, dim_out, H/2, W/2)

1.3 标准化

  • WeightStandardizedConv2d:实现了加权标准化卷积层,使用加权标准化(weight standardization)来提高卷积层的训练效率。
python 复制代码
class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        weight = self.weight
        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()
        return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

1.4 归一化层

  • LayerNorm:自定义的层归一化,标准化每个通道的特征图。
python 复制代码
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g
  • PreNorm:先进行归一化,再传递给后续函数处理。
python 复制代码
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

1.5 提取函数 extract

python 复制代码
def extract(a, t, x_shape): 
    b, *_ = t.shape  # 从 t 中获取 batch 大小 b,通常是输入数据的第一个维度
    out = a.gather(-1, t)  # 使用 PyTorch 的 gather 函数,从 a 中按时间步索引 t 提取对应的参数
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # 将提取出的参数调整为形状 [B, 1, ..., 1],使得与 x_shape 的维度一致,便于后续广播

这段代码的目的是从参数向量 a 中按时间步索引 t 提取对应的值 ,并调整其形状[b, 1, 1, ..., 1],以兼容目标张量 x_shape 的广播操作

这种形式在扩散模型中非常常见,用于将时间相关的参数扩展到整个张量操作

输入参数

  • a:表示一个预先计算好的时间步相关参数(如扩散模型中的 alphas_cumprodsnr),通常是 [T] 的张量。
  • t:表示当前 batch 中每个样本的时间步索引([B] 的整数张量)。
  • x_shape:输入数据的形状,用于对结果进行广播。

难点理解 return out.reshape(b, *((1,) * (len(x_shape) - 1)))

  1. b

    • 表示批量大小,通常是输入数据的第一个维度。
    • 例如,如果 x_shape = (8, 3, 32, 32)(批大小为 8),则 b = 8
  2. len(x_shape) - 1

    • 表示目标张量的维度数减去 1。
    • 例如,对于 x_shape = (8, 3, 32, 32)len(x_shape) = 4,因此 len(x_shape) - 1 = 3
  3. (1,) * (len(x_shape) - 1)

    • 生成一个包含 len(x_shape) - 1 个值为 1 的元组。
    • 例如,如果 len(x_shape) - 1 = 3,则结果是 (1, 1, 1)
  4. out.reshape(b, *((1,) * (len(x_shape) - 1)))

    • out 的形状调整为 [b, 1, 1, 1, ...],以便其与 x_shape 的维度兼容。
    • b 决定第一个维度,其余维度填充 1,为后续广播做准备。

在 PyTorch 中,广播机制会自动扩展维度为 1 的张量,以匹配目标张量的维度。例如:

  • 如果目标张量的形状是 [8, 3, 32, 32],一个形状为 [8, 1, 1, 1] 的张量会被广播为 [8, 3, 32, 32]

【举例说明】

输入数据:

python 复制代码
a = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # [2, 3]
t = torch.tensor([0, 2])  # 每个 batch 的时间步索引
x_shape = (2, 3, 32, 32)  # 假设目标张量形状

执行过程:

  1. a.gather(-1, t)

    • gather 根据 ta 的最后一维提取值。a 的形状是 [2, 3],它的最后一维大小为 3
    python 复制代码
    a = [[0.1, 0.2, 0.3],   # 第一批次
         [0.4, 0.5, 0.6]]   # 第二批次
    • 对第一批次(索引 t[0] = 0):从 [0.1, 0.2, 0.3] 提取第 0 个值,得到 0.1
    • 对第二批次(索引 t[1] = 2):从 [0.4, 0.5, 0.6] 提取第 2 个值,得到 0.6
    • 结果:out = [0.1, 0.6],形状为 [2]
  2. (1,) * (len(x_shape) - 1)

    • x_shape = (2, 3, 32, 32),因此 len(x_shape) - 1 = 3
    • 结果:(1, 1, 1)
  3. out.reshape(b, *((1,) * (len(x_shape) - 1)))

    • b = 2,结果形状为 [2, 1, 1, 1]

out(形状为 [2, 1, 1, 1])与目标张量(例如 [2, 3, 32, 32])相乘时:out 会自动广播为 [2, 3, 32, 32]

2. 构建网络块类

2.1 block()

python 复制代码
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

这段代码定义了一个名为 Block 的类,它继承自 nn.Module,通常用于构建神经网络模型中的模块。这个 Block 类包括一个卷积层、归一化层、激活函数,以及可选的 scale_shift 操作。

整体作用

  • Block 类构建了一个标准的深度学习模块,常用于卷积神经网络中的基本单元。
  • 先通过标准化卷积层提取特征 ,再进行分组归一化 ,最后使用 SiLU 激活函数进行非线性变换scale_shift 提供了额外的灵活性,使得模块可以在某些应用中进行动态调整。

初始化方法__init__ 方法用于初始化类的实例

  • 参数:
    • dim: 输入通道的数量。
    • dim_out: 输出通道的数量。
    • groups: 归一化层的分组数,默认为 8。
  • super().__init__() 调用父类 nn.Module 的构造方法。
  • 初始化的组件:
    • self.proj:一个卷积层,使用自定义的 WeightStandardizedConv2d(这是一个标准化权重的卷积层),卷积核大小为 3,填充为 1。
    • self.normnn.GroupNorm 归一化层,使用 groups 参数对输入通道进行分组归一化。
    • self.act:激活函数 nn.SiLU(),一种平滑的激活函数,也称为 Swish 激活函数。

forward 前向方法定义了数据如何通过网络进行传递

  • 参数:

    • x:输入张量。
    • scale_shift:一个可选参数,包含 (scale, shift),用于缩放和平移 x
  • 前向过程:

    • x = self.proj(x)将输入 x 通过卷积层进行卷积操作
    • x = self.norm(x)将卷积后的结果进行分组归一化
    • 检查 scale_shift 是否存在(使用 exists(scale_shift))。如果存在,将 scale_shift 拆分为 scaleshift,其中 scale 是缩放因子,shift 是偏移量 。并应用以下变换:x = x * (scale + 1) + shift,将 x 进行缩放和偏移调整。
    • x = self.act(x):将结果通过激活函数 SiLU 激活。
    • 返回结果 x

2.2 残差连接

Residual :残差连接模块,接受一个函数作为输入, 返回该函数的输出与输入的和

python 复制代码
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

2.3 ResnetBlock

python 复制代码
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)

        return h + self.res_conv(x)

这段代码定义了一个 ResnetBlock 类,它是用于深度学习模型中的残差块,继承自 nn.ModuleResnetBlock 包括两个 Block 层和一个残差连接,并可以在需要时接受时间嵌入(time_emb)用于调节网络行为。

__init__ 方法 初始化组件:

  • self.mlp: 一个 MLP 层,用于将时间嵌入投影到 dim_out * 2 维度。如果 time_emb_dim 存在,则创建此 MLP;否则为 None
  • self.block1self.block2: 两个 Block 实例,分别用于特征提取
  • self.res_conv: 一个 1x1 卷积层(或恒等映射),用于调整输入 x 与输出 h 在维度上的匹配。如果 dimdim_out 不相同,则使用卷积层进行调整。

具体来说,

self.block1 = Block(dim, dim_out, groups = groups)

self.block2 = Block(dim_out, dim_out, groups = groups)

这两行代码实例化了两个 Block 层,并在 ResnetBlock 中分别赋值给 self.block1self.block2。每个 Block 对象的初始化会执行 Block 类的 __init__ 方法中的内容,

  • 当创建 self.block1 时,Block 的 __init__ 方法执行以下步骤:
    self.proj 初始化为一个 WeightStandardizedConv2d 卷积层,它的输入通道数为 dim,输出通道数为 dim_out,卷积核大小为 3,并且带有 1 个像素的填充
    self.norm 初始化为 GroupNorm,用于将输出通道 dim_out 进行分组归一化。
    self.act 初始化为 SiLU 激活函数。
  • 创建 self.block2 时,执行了 Block 的 __init__ 方法,但与 self.block1 的主要区别是:
    这次 proj 卷积层的输入通道和输出通道都是 dim_out ,使得输出通道数保持不变

forward 方法:

  • 接收输入 x 和可选的时间嵌入 time_emb
  • 如果 self.mlp 存在且 time_emb 存在,则将 time_emb 通过 MLP,并通过 rearrange 重塑为适合卷积操作的形状。然后将其切分为 scaleshift(用于 scale_shift 操作)。
  • 通过 block1 进行前向传播,并应用 scale_shift
  • 经过 block2 进一步处理。
  • 最后,返回 h + self.res_conv(x),实现残差连接

2.3 Attention

python 复制代码
class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

这段代码定义了一个用于图像数据的自注意力机制 模块,Attention,它继承自 nn.Module 并用于深度学习模型中。

  1. __init__ 方法:

    • 接收 dim(输入通道数)、heads(多头注意力头的数量)、dim_head(每个注意力头的维度)作为参数。
    • 计算 hidden_dimdim_head * heads,这是多头注意力中每个 QKV 的总维度。
    • self.scale 用于缩放 Q 向量,以稳定训练。
    • self.to_qkv 是一个 1x1 卷积层,将输入的特征图变换为 Q(查询)、K(键)和 V(值)向量 ,输出通道数是 hidden_dim * 3
    • self.to_out 是另一个 1x1 卷积层,用于将注意力机制的输出映射回输入的维度 dim
  2. forward 方法:

    • 输入 x 是一个四维张量,形状为 (batch_size, channels, height, width)
    • 使用 self.to_qkv(x) 生成 QKV,并通过 chunk(3, dim=1) 将它们分开。
    • 使用 map()rearrange()QKV 重排,使它们适应多头注意力的形状 (batch_size, heads, dim_head, tokens),其中 tokens = height * width
    • Q 向量乘以 self.scale 进行缩放。
    • 计算相似度矩阵 sim,通过 einsum('b h d i, b h d j -> b h i j', q, k) 实现,表示 QK 之间的点积。
    • 使用 softmax 计算注意力权重 attn
    • 计算注意力加权后的输出 out,通过 einsum('b h i j, b h d j -> b h i d', attn, v) 将注意力矩阵应用于 V
    • 重排输出形状回 (batch_size, channels, height, width)
    • 最后,使用 self.to_out(out) 将输出映射回原输入通道数。

该模块用于提取图像特征的自注意力机制,帮助模型在处理复杂输入数据时,捕获长距离依赖和上下文信息。

相关推荐
网易独家音乐人Mike Zhou2 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
安静读书2 小时前
Python解析视频FPS(帧率)、分辨率信息
python·opencv·音视频
小二·4 小时前
java基础面试题笔记(基础篇)
java·笔记·python
一念之坤7 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python
wxl7812277 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
NoneCoder7 小时前
Python入门(12)--数据处理
开发语言·python
LKID体8 小时前
Python操作neo4j库py2neo使用(一)
python·oracle·neo4j
小尤笔记8 小时前
利用Python编写简单登录系统
开发语言·python·数据分析·python基础
FreedomLeo18 小时前
Python数据分析NumPy和pandas(四十、Python 中的建模库statsmodels 和 scikit-learn)
python·机器学习·数据分析·scikit-learn·statsmodels·numpy和pandas