Python 神经网络项目常用语法

  • [一、 工具](#一、 工具)
    • [1. 导入模块和包](#1. 导入模块和包)
    • [2. 修改系统路径 (sys.path.append)](#2. 修改系统路径 (sys.path.append))
    • [3. 命令行参数解析 (argparse 模块)](#3. 命令行参数解析 (argparse 模块))
    • [4. 字符串格式化](#4. 字符串格式化)
    • [5. 获取设备](#5. 获取设备)
    • [6. 加载数据](#6. 加载数据)
    • 7.设置推理结果的子路径
    • [8. main() 脚本入口点](#8. main() 脚本入口点)
  • 二、类相关
    • [1. 类的定义及初始化](#1. 类的定义及初始化)
    • [2. 类的实例化及函数调用](#2. 类的实例化及函数调用)
  • 三、神经网络常用类
    • [1. 工具类](#1. 工具类)
      • [1.1 正弦位置编码类](#1.1 正弦位置编码类)
      • [1.2 上采样/下采样](#1.2 上采样/下采样)
      • [1.3 标准化](#1.3 标准化)
      • [1.4 归一化层](#1.4 归一化层)
      • [1.5 提取函数 extract](#1.5 提取函数 extract)
    • [2. 构建网络块类](#2. 构建网络块类)
      • [2.1 block()](#2.1 block())
      • [2.2 残差连接](#2.2 残差连接)
      • [2.3 ResnetBlock](#2.3 ResnetBlock)
      • [2.3 Attention](#2.3 Attention)

一、 工具

1. 导入模块和包

python 复制代码
import os
import argparse
import sys as s
from accelerate import Accelerator
  • import:用于导入模块和包,可以选择导入单个模块、多个模块。
  • from ... import ...从特定模块中导入具体的类、函数或变量。
  • impoert ... as ...:可以为导入的模块指定别名,使代码简洁。

2. 修改系统路径 (sys.path.append)

python 复制代码
# 返回当前脚本所在目录的父目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 

# 返回当前脚本所在目录的上上级目录
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

这两行代码通过 os.path.dirname(__file__) 获取当前脚本所在的目录,然后使用 .. 依次返回上一级目录。

  • sys.path.append()修改模块搜索路径 ,可以动态添加额外的搜索路径,以便访问其他目录中的模块。
  • os.path.join(a, b, ...)将多个路径部分连接成一个完整的路径 ,并确保使用正确的路径分隔符 (在 UNIX 系统上是 /,在 Windows 上是 \)。
    • 例如,os.path.join("/home/user", "project") 将返回 /home/user/project
  • os.path.dirname(__file__).... 是用来构建路径的。
    • __file__ 是一个内置变量 ,它指向当前脚本文件的路径 (字符串格式)。例如,如果脚本文件的路径是 /home/user/project/scripts/train.py,则 __file__ 就是这个文件的完整路径。
    • os.path.dirname(path) 函数返回给定路径 path 的父目录(即去掉路径中的文件名部分)。例如,如果 __file__/home/user/project/scripts/train.py,那么 os.path.dirname(__file__) 将返回 /home/user/project/scripts,即文件所在的目录。
    • .. 在路径中代表上一级目录 ,是一个相对路径。如果当前路径是 /home/user/project/scripts,则 os.path.join("/home/user/project/scripts", '..') 会返回 /home/user/project,即当前脚本所在目录的父目录。

3. 命令行参数解析 (argparse 模块)

python 复制代码
import argparse

# 创建 ArgumentParser 对象,description 参数会显示在帮助信息中
parser = argparse.ArgumentParser(description='Train EBM model')


# 添加命令行参数,指定名称、默认值、类型、帮助信息
parser.add_argument('--model_type', default="states", type=str,
                    help='choices: states | thetas')
parser.add_argument('--dataset', default='jellyfish', type=str, help='dataset to evaluate')
parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
parser.add_argument('--epochs', default=10, type=int, help='Number of epochs to train the model')

# 解析命令行参数并将其存储在 FLAGS 对象中
FLAGS = parser.parse_args()

# 使用条件判断 if-elseif-else 语句,根据 model_type 选择相应的操作
if FLAGS.model_type == "states":
    ...
elif FLAGS.model_type == "thetas":
    ...

# 输出解析的参数值
print("Dataset:", FLAGS.dataset)
print("Batch size:", FLAGS.batch_size)
print("Epochs:", FLAGS.epochs)

在大多数深度学习训练脚本中,使用 argparse 模块来处理命令行参数

  • parser = argparse.ArgumentParser():创建一个命令行参数解析器 parserparser 是通过 argparse.ArgumentParser() 创建的实例,它负责定义和解析命令行输入。
  • parser.add_argument()添加命令行参数 ,指定名称、类型、默认值及帮助信息。
    • default参数的默认值,如果命令行未提供该参数,则使用默认值。
    • type:定义参数的数据类型
    • help:提供该参数的帮助描述
  • parser.parse_args()从命令行中解析传入的参数 ,并将它们存储在 args 对象中。args 是一个命名空间对象,可以通过点语法访问各个参数。

运行命令行的不同示例:

  • 不传递任何参数,使用默认值

    bash 复制代码
     # 命令行
     python script.py

    输出:

    python 复制代码
    Dataset: jellyfish
    Batch size: 4
    Epochs: 10
  • 传递自定义的命令行参数

    bash 复制代码
    # 命令行
    python script.py --dataset "fish" --batch_size 8 --epochs 20

    输出:

    python 复制代码
    Dataset: fish
    Batch size: 8
    Epochs: 20
  • 查看帮助信息

    bash 复制代码
    python script.py --help

    输出:

    python 复制代码
    usage: script.py [-h] [--dataset DATASET] [--batch_size BATCH_SIZE] [--epochs EPOCHS]
    Train EBM model
    optional arguments:
     -h, --help                   show this help message and exit
     --dataset DATASET            Dataset to evaluate
     --batch_size BATCH_SIZE      Batch size for training
     --epochs EPOCHS              Number of epochs to train the model

4. 字符串格式化

python 复制代码
print("Saved at: ", results_path)
print("DATA_PATH: ", DATA_PATH)
print("number of parameters in model: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
  • print():用于输出信息。
  • 字符串连接:print 语句中使用逗号分隔多个变量 ,可以连接字符串和变量

5. 获取设备

python 复制代码
def get_device():
    return torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

args.device = get_device()

在深度学习任务中,通常需要指定训练的设备,如果你的机器有支持的 GPU,cuda 可以加速模型训练。如果没有 GPU,则会使用 CPU。

get_device() 这个函数会检查当前是否有可用的 GPU

  • torch.device("cuda:4") 中的 cuda:4 表示选择第 5 个 GPU。如果你的机器没有 5 个以上的 GPU,可能会报错。你可以使用 cudacuda:0(通常默认使用第一个 GPU)。
  • torch.cuda.is_available() 用来检查当前是否有可用的 GPU,如果没有返回 False,则会回退到 CPU。

args.device = get_device() 调用 get_device() 函数,将返回的设备对象(cuda:4cpu)赋值给 args.deviceargs.device 之后可以用来指定模型的设备,确保模型训练或推理时使用正确的硬件资源。

上述代码可修改为:

python 复制代码
def get_device():
    if torch.cuda.is_available():
        # 如果有多个 GPU,选择第一个 GPU,避免 'cuda:4' 报错
        return torch.device("cuda:0")  
    else:
        return torch.device("cpu")

args.device = get_device()

6. 加载数据

python 复制代码
def cycle(dl):
    while True:
        for data in dl:
            yield data

无限循环地返回数据加载器中的数据。

7.设置推理结果的子路径

python 复制代码
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
args.inference_result_subpath = os.path.join(
    args.inference_result_path,
    current_time +  "_coeff_ratio_w_{}_J{}_".format(args.coeff_ratio_w, args.coeff_ratio_J)
)

这段代码构建推理结果保存的子目录路径 inference_result_subpath,这个路径将包含当前时间戳 以及与 coeff_ratio_wcoeff_ratio_J 参数相关的信息。

  • 使用 datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 获取当前的时间戳,格式为 年-月-日_时-分-秒
  • 使用 args.coeff_ratio_wargs.coeff_ratio_J 格式化字符串,这两个参数的值会被嵌入路径中 。这些参数可能是用户在命令行中设置的值,影响模型训练或推理过程中的某些超参数
  • 最终路径会类似于:inference_result_path/2024-11-17_12-30-45_coeff_ratio_w_0.3_J0.5_

8. main() 脚本入口点

python 复制代码
# inference.py
def main(args):
    diffusion = load_model(args)
    dataloader = load_data(args)
    inference(dataloader, diffusion, args)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='xxx model')
	...
	args = parser.parse_args()
	main(args)  # # 调用 main 函数

main 函数通常是脚本的核心逻辑部分,处理程序的主要任务。在 main() 函数中,可以使用 args 对象中的参数值来决定程序的行为,比如设置模型的超参数、选择训练模式等。

这段代码展示了 main 函数的核心逻辑,主要包括以下几个步骤:

  1. 加载模型diffusion = load_model(args)):这行代码通过调用 load_model 函数加载模型及相关组件。args 参数传递了外部配置,可能包含模型路径、超参数等信息。load_model 函数返回 diffusion 扩散模型对象。

  2. 加载数据dataloader = load_data(args)):这行代码调用 load_data 函数加载数据 。args 中包含有关数据的路径或其他配置信息。dataloader 通常是一个生成数据批次(batch)的迭代器,可以用于训练或推理过程。

  3. 推理inference(dataloader, diffusion, args))在这行代码中,inference 函数接受输入数据的迭代器、扩散模型和配置信息 参数,inference 函数通常用于执行模型的推理过程,生成预测或结果

  • if __name__ == '__main__' 这一行的作用是确保当这个脚本作为主程序执行时,main() 函数会被调用。
  • main(args) 这行代码调用 main() 函数 ,并将 args 对象作为参数传递给它

因此,如果这个脚本是作为独立的 Python 文件运行的,它会执行 main 函数

当你运行命令 python inference.py 时:

  • Python 解释器会执行整个 inference.py 文件
  • 脚本中的 if __name__ == '__main__': 块会被触发。
  • 然后会调用 main(args) 函数,并传入通过 argparse 解析的命令行参数 args。

这意味着,main 函数将会被执行,在 main 中使用命令行参数 args 来配置模型、加载数据、进行推理等

二、类相关

要想使用类,先使用 class 关键字去定义类 ,而如何使用类在类方法中定义 ,然后再用实例化对象调用类方法

  • 有个特殊参数 每个类都有,它叫 self,它代表类的当前实例 ,使得方法可以访问和修改实例的属性

  • 有个特殊的类方法 每个类都会定义,它叫 __init__初始化方法 或构建方法,它会在这个类创建实例化对象自动被调用 ,并且传入实例化时的参数 。通常在该方法中对实例属性进行初始化

python 复制代码
# 定义一个类
class Dog:
    def __init__(self, name, age):  # 初始化方法
        self.name = name  # 初始化属性
        self.age = age

    def speak(self):  # 类方法
        print(f"{self.name} says woof!")

# 实例化一个对象
dog1 = Dog("Buddy", 3)  # 创建 Dog 类的实例,__init__ 方法自动执行

# 调用类方法
dog1.speak()  # 输出: Buddy says woof!
  • Dog 是类,定义了 __init__ 方法来初始化 name 和 age 属性。
  • speak() 是一个类方法,它输出 name 属性的内容。
  • 通过 dog1 = Dog("Buddy", 3) 创建了一个 Dog 类的实例,__init__ 方法自动被调用,初始化了 dog1 的属性。
  • dog1.speak() 实例化对象 dog1 调用类方法,使用类。

1. 类的定义及初始化

__init__ 是 Python 中的类初始化方法 ,也叫构造方法 。它用于在类的对象被创建时初始化对象的状态(即设置对象的属性)

__init__ 方法会在类的实例化时自动调用,并且在对象创建后执行

python 复制代码
class ClassName:
    def __init__(self, parameters):
        # 初始化属性或执行其他必要的操作
        self.attribute = value
        # 其他代码
  • __init__(self, parameters)__init__ 方法接受至少一个参数,通常是 self ,它表示类的实例对象 。parameters 是该方法接受的其他参数 ,用于在初始化时传递值
  • self.attribute:self 表示当前对象实例,可以通过它访问类的属性和方法
  • value 是初始化时为属性赋的值,可以是常量、变量或通过其他逻辑生成的值。

示例 1:基础初始化方法

python 复制代码
class Person:
    def __init__(self, name, age):
        # 初始化时将名字和年龄赋值给对象的属性
        self.name = name
        self.age = age

    def introduce(self):
        print(f"Hello, my name is {self.name} and I am {self.age} years old.")

在这个例子中,Person 类的 __init__ 方法接受两个其他参数:name 和 age,并将它们赋值给实例对象的属性 self.nameself.age

python 复制代码
# 创建对象时,初始化属性
person1 = Person("Alice", 30)
person2 = Person("Bob", 25)

# 调用方法
person1.introduce()  # 输出: Hello, my name is Alice and I am 30 years old.
person2.introduce()  # 输出: Hello, my name is Bob and I am 25 years old.

示例 2:使用默认参数

python 复制代码
class Car:
    def __init__(self, make, model, year=2020):
        # 初始化时,year 如果未传递将默认为 2020
        self.make = make
        self.model = model
        self.year = year

    def display_info(self):
        print(f"{self.year} {self.make} {self.model}")

在这个例子中,year 参数具有默认值 2020。如果在创建 Car 实例时未传递 year 参数,它会自动使用默认值 2020。

python 复制代码
# 创建时传递所有参数
car1 = Car("Toyota", "Camry", 2021)

# 创建时只传递 make 和 model,year 会使用默认值 2020
car2 = Car("Honda", "Civic")

# 输出汽车信息
car1.display_info()  # 输出: 2021 Toyota Camry
car2.display_info()  # 输出: 2020 Honda Civic

2. 类的实例化及函数调用

python 复制代码
# train_script.py

# Unet3D_with_Conv3D、GaussianDiffusion、Trainer 是从模块中导入的类
from model.video_diffusion_pytorch.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D
from diffusion.diffusion_2d_jellyfish import GaussianDiffusion, Trainer

if __name__ == "__main__":
	# 解析命令行参数并将其存储在 FLAGS 对象中
	FLAGS = parser.parse_args()

	# 创建 Unet3D_with_Conv3D 模型实例
	model = Unet3D_with_Conv3D(
		dim = 64,  # 设置模型的基础维度大小
		out_dim = 1 if FLAGS.only_vis_pressure else 3,  # 根据命令行参数 only_vis_pressure 决定输出维度
		dim_mults = (1, 2, 4),  # 传递一个元组作为参数,用于指定每个网络层维度的倍数
		channels=5 if FLAGS.only_vis_pressure else 7  # 根据命令行参数 only_vis_pressure 决定通道数
		)
        
	# 创建 GaussianDiffusion 实例
	diffusion = GaussianDiffusion(
		model,
        image_size = 64,
        frames=FLAGS.frames,
        cond_steps=FLAGS.cond_steps,
        timesteps = 1000,           # 设置扩散步骤数
        sampling_timesteps = 250,   # 采样步骤数
        loss_type = 'l2',           # 设置损失函数类型:L1 or L2
        objective = "pred_noise",
        device =device              # 模型运行的设备(CPU/GPU)
    )
    
	# 创建 Trainer 类的实例,该类用于管理模型的训练
	trainer = Trainer(
        diffusion,
        FLAGS.dataset,
        FLAGS.dataset_path,
        FLAGS.frames,
        FLAGS.traj_len,
        FLAGS.ts,
        FLAGS.log_path,
        train_batch_size = FLAGS.batch_size,  # 训练的批次大小
        train_lr = 1e-3,                  # 学习率
        train_num_steps = 400000,         # 总训练步数
        gradient_accumulate_every = 1,    # 指定进行梯度累积的次数
        ema_decay = 0.995,                # 用于模型参数的指数移动平均值的衰减因子
        save_and_sample_every = 4000,     # 每 4000 步保存模型和进行采样
        results_path = results_path,
        amp = False,                      # 是否使用混合精度训练
        calculate_fid = False,            # 训练过程中是否计算 fid
        is_testdata = FLAGS.is_testdata,
        only_vis_pressure = FLAGS.only_vis_pressure,
        model_type = FLAGS.model_type
    	)
    
	trainer.train()  # 调用 Trainer 类的 train 方法,启动模型的训练过程

这段代码展示了如何定义和使用深度学习模型的训练流程 ,包括模型定义、模型实例化、训练参数设置 ,以及如何通过面向对象编程实现模块化if __name__ == "__main__" 使得这段代码在直接运行脚本时会执行训练逻辑,而在导入时不会执行,从而提高了代码的复用性和模块化水平。

  • if __name__ == "__main__":它是模块和脚本的运行入口 。该语句下的代码仅在该脚本作为主程序运行时才会被执行

    • __name__ 变量:每个 Python 模块都有一个内置属性 __name__,其值决定了模块是被导入还是直接运行
    • __main__:当一个 Python 文件被直接运行 时,__name__ 的值会被设置为 __main__
    • 导入时的行为:如果该模块被其他脚本导入__name__ 的值是该模块的文件名 (不带路径和 .py 扩展名)。
  • model = ClassName(arguments)创建类的实例 ,通过类构造函数 ClassName 初始化对象 model

  • out_dim = 1 if FLAGS.only_vis_pressure else 3:使用条件表达式(类似三元运算符)来设置输出维度,如果 FLAGS.only_vis_pressure 为真,out_dim 为 1,否则为 3。

示例:

python 复制代码
# main_script.py
if __name__ == "__main__":
    print("This will only run when main_script.py is executed directly.")

运行结果:

  • 如果运行 python main_script.py,将输出 This will only run when main_script.py is executed directly.
  • 如果 main_script.py 被其他脚本导入,如 import main_script,这行代码不会被执行。

三、神经网络常用类

1. 工具类

1.1 正弦位置编码类

常见于自然语言处理(如 Transformer)和其他序列建模 任务中。正弦位置编码是一种通过正弦和余弦函数表示位置 的方式,使得模型能够感知输入数据中元素的顺序

python 复制代码
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
  • dim 是一个超参数,表示生成的位置编码的维度(即输出编码的总维度)。通常,这个维度应当是偶数,因为在后续计算中使用了正弦和余弦函数。
  • forward 方法 定义了正弦位置编码的计算过程
  • half_dim = self.dim // 2:将 dim 除以 2,计算出编码的一半维度,这将用于后续的正弦和余弦计算。
  • emb = math.log(10000) / (half_dim - 1):计算一个常数,它与位置编码的尺度有关。这里的 10000 是一个经验值,常用于位置编码的计算。
  • torch.exp(torch.arange(half_dim, device=device) * -emb):生成一个递减的指数序列
    torch.arange(half_dim) 生成一个从 0 到 half_dim-1 的整数序列 。然后通过乘以 -emb,再取指数,得到一组缩放因子,这些因子将用于后续的正弦和余弦函数。
  • x[:, None]:通过切片操作将 x 的维度从 [batch_size] 扩展到 [batch_size, 1],将其广播到与 emb 相乘。这意味着每个位置 x 的值将与 emb 中的每个尺度因子相乘,生成一个位置的尺度序列
    emb[None, :]:emb 被扩展到 [1, half_dim],然后与 x[:, None] 进行广播相乘,生成一个包含每个位置的编码因子
  • emb.sin()emb.cos():对每个位置的编码因子 ,分别计算正弦和余弦值。
    torch.cat((emb.sin(), emb.cos()), dim=-1):将计算得到的正弦和余弦值沿最后一个维度(dim=-1)连接 起来,形成最终的编码。这将使得每个位置的编码有两倍的维度(half_dim 为每种类型,正弦和余弦各一半)。
  • return emb:返回生成的位置编码张量 emb。

位置编码为输入序列的每个位置生成一个向量 ,使得模型可以感知不同元素的相对位置。位置编码通过正弦和余弦函数的组合来实现,这种设计能够让模型在不同的尺度上感知位置信息,且不依赖于具体的训练数据。

该方法的关键优势是它为每个位置生成了一个独特的编码 ,不同的位置信息通过正弦和余弦函数映射到不同的维度上 ,能够捕捉到位置之间的相对关系

RandomOrLearnedSinusoidalPosEmb:可以选择使用随机的或学习的正弦位置编码。

python 复制代码
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim, is_random=False):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad=not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi

1.2 上采样/下采样

  • Upsample:上采样模块,使用最近邻插值方法和卷积层进行特征图的上采样。
python 复制代码
def Upsample(dim, dim_out = None):
    return nn.Sequential(
		# 上采样操作,将输入特征图的尺寸扩大 2 倍,使用最近邻插值进行上采样
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        # 卷积层,输入通道数为 dim,输出通道数为 dim_out(如果 dim_out 未定义,则默认为 dim),卷积核大小为 3x3
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )
  • Downsample :下采样模块,使用 Rearrange 层对特征图进行降维。
python 复制代码
def Downsample(dim, dim_out = None):
    return nn.Sequential(
		# 输入张量的形状是 (batch_size, c, h, w),通过 p1=2 和 p2=2,它将 h 和 w 都按 2 的倍数重排列,
		# 因此,每个空间位置的特征将从 c 维度与 h 和 w 的子块合并,形成 dim * 4 个特征通道。这就相当于一个下采样操作,将空间尺寸减小一半。
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        # 卷积层,用于将通道数从 dim * 4 转换为 dim_out,dim_out 默认为 dim。积核大小为 1x1,通常用于改变通道数而不改变空间尺寸。
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

这两个函数 UpsampleDownsample 定义了数据的上采样和下采样操作。它们对输入数据的空间分辨率通道数进行处理,其工作方式如下:

  1. 上采样 Upsample(dim, dim_out):将空间分辨率放大 2 倍(从 (H, W)(2H, 2W)),并通过卷积调整通道数。
  • nn.Upsample(scale_factor=2, mode='nearest'):

    • 使用最近邻插值法,将输入的空间尺寸(高度和宽度)扩大为原来的 2 倍。
    • 例如,输入张量大小为 (B, C, H, W),则经过这一步后,大小变为 (B, C, 2H, 2W)
  • nn.Conv2d(dim, dim_out, 3, padding=1):

    • 卷积层,使用 3 × 3 3 \times 3 3×3 的卷积核对放大的数据进行处理,调整通道数为 dim_out
    • padding=1 确保空间分辨率保持不变 (仍为 (2H, 2W))。
    • 如果 dim_out 未指定,默认设置为 dim,即输出的通道数与输入一致。

最终,Upsample 的效果是

  • 输入维度 : (B, C, H, W)
  • 输出维度 : (B, dim_out, 2H, 2W)
  1. 下采样Downsample(dim, dim_out):将空间分辨率缩小 2 倍(从 (H, W)(H/2, W/2)),并通过通道重排和卷积调整通道数。
  • Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2):

    • 通过重排张量,将空间维度分块到通道维度。
    • 将每 2 × 2 2 \times 2 2×2 的块合并成新的通道
    • 输入的大小为 (B, C, H, W),输出的大小变为 (B, 4C, H/2, W/2)

    重排细节

    • 空间维度 (H, W) 被分成 H/2W/2 ,每个分块是 2 × 2 2 \times 2 2×2。
    • C 通道被扩展为 4C ,因为 2 × 2 = 4 2 \times 2 = 4 2×2=4。
  • nn.Conv2d(dim * 4, dim_out, 1):

    • 使用 1 × 1 1 \times 1 1×1 的卷积核对重新排列的通道进行压缩,调整通道数为 dim_out
    • 如果 dim_out 未指定,默认值是 dim

最终,Downsample 的效果是

  • 输入维度 : (B, C, H, W)
  • 输出维度 : (B, dim_out, H/2, W/2)

1.3 标准化

  • WeightStandardizedConv2d:实现了加权标准化卷积层,使用加权标准化(weight standardization)来提高卷积层的训练效率。
python 复制代码
class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        weight = self.weight
        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()
        return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

1.4 归一化层

  • LayerNorm:自定义的层归一化,标准化每个通道的特征图。
python 复制代码
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g
  • PreNorm:先进行归一化,再传递给后续函数处理。
python 复制代码
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

1.5 提取函数 extract

python 复制代码
def extract(a, t, x_shape): 
    b, *_ = t.shape  # 从 t 中获取 batch 大小 b,通常是输入数据的第一个维度
    out = a.gather(-1, t)  # 使用 PyTorch 的 gather 函数,从 a 中按时间步索引 t 提取对应的参数
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # 将提取出的参数调整为形状 [B, 1, ..., 1],使得与 x_shape 的维度一致,便于后续广播

这段代码的目的是从参数向量 a 中按时间步索引 t 提取对应的值 ,并调整其形状[b, 1, 1, ..., 1],以兼容目标张量 x_shape 的广播操作

这种形式在扩散模型中非常常见,用于将时间相关的参数扩展到整个张量操作

输入参数

  • a:表示一个预先计算好的时间步相关参数(如扩散模型中的 alphas_cumprodsnr),通常是 [T] 的张量。
  • t:表示当前 batch 中每个样本的时间步索引([B] 的整数张量)。
  • x_shape:输入数据的形状,用于对结果进行广播。

难点理解 return out.reshape(b, *((1,) * (len(x_shape) - 1)))

  1. b

    • 表示批量大小,通常是输入数据的第一个维度。
    • 例如,如果 x_shape = (8, 3, 32, 32)(批大小为 8),则 b = 8
  2. len(x_shape) - 1

    • 表示目标张量的维度数减去 1。
    • 例如,对于 x_shape = (8, 3, 32, 32)len(x_shape) = 4,因此 len(x_shape) - 1 = 3
  3. (1,) * (len(x_shape) - 1)

    • 生成一个包含 len(x_shape) - 1 个值为 1 的元组。
    • 例如,如果 len(x_shape) - 1 = 3,则结果是 (1, 1, 1)
  4. out.reshape(b, *((1,) * (len(x_shape) - 1)))

    • out 的形状调整为 [b, 1, 1, 1, ...],以便其与 x_shape 的维度兼容。
    • b 决定第一个维度,其余维度填充 1,为后续广播做准备。

在 PyTorch 中,广播机制会自动扩展维度为 1 的张量,以匹配目标张量的维度。例如:

  • 如果目标张量的形状是 [8, 3, 32, 32],一个形状为 [8, 1, 1, 1] 的张量会被广播为 [8, 3, 32, 32]

【举例说明】

输入数据:

python 复制代码
a = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])  # [2, 3]
t = torch.tensor([0, 2])  # 每个 batch 的时间步索引
x_shape = (2, 3, 32, 32)  # 假设目标张量形状

执行过程:

  1. a.gather(-1, t)

    • gather 根据 ta 的最后一维提取值。a 的形状是 [2, 3],它的最后一维大小为 3
    python 复制代码
    a = [[0.1, 0.2, 0.3],   # 第一批次
         [0.4, 0.5, 0.6]]   # 第二批次
    • 对第一批次(索引 t[0] = 0):从 [0.1, 0.2, 0.3] 提取第 0 个值,得到 0.1
    • 对第二批次(索引 t[1] = 2):从 [0.4, 0.5, 0.6] 提取第 2 个值,得到 0.6
    • 结果:out = [0.1, 0.6],形状为 [2]
  2. (1,) * (len(x_shape) - 1)

    • x_shape = (2, 3, 32, 32),因此 len(x_shape) - 1 = 3
    • 结果:(1, 1, 1)
  3. out.reshape(b, *((1,) * (len(x_shape) - 1)))

    • b = 2,结果形状为 [2, 1, 1, 1]

out(形状为 [2, 1, 1, 1])与目标张量(例如 [2, 3, 32, 32])相乘时:out 会自动广播为 [2, 3, 32, 32]

2. 构建网络块类

2.1 block()

python 复制代码
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

这段代码定义了一个名为 Block 的类,它继承自 nn.Module,通常用于构建神经网络模型中的模块。这个 Block 类包括一个卷积层、归一化层、激活函数,以及可选的 scale_shift 操作。

整体作用

  • Block 类构建了一个标准的深度学习模块,常用于卷积神经网络中的基本单元。
  • 先通过标准化卷积层提取特征 ,再进行分组归一化 ,最后使用 SiLU 激活函数进行非线性变换scale_shift 提供了额外的灵活性,使得模块可以在某些应用中进行动态调整。

初始化方法__init__ 方法用于初始化类的实例

  • 参数:
    • dim: 输入通道的数量。
    • dim_out: 输出通道的数量。
    • groups: 归一化层的分组数,默认为 8。
  • super().__init__() 调用父类 nn.Module 的构造方法。
  • 初始化的组件:
    • self.proj:一个卷积层,使用自定义的 WeightStandardizedConv2d(这是一个标准化权重的卷积层),卷积核大小为 3,填充为 1。
    • self.normnn.GroupNorm 归一化层,使用 groups 参数对输入通道进行分组归一化。
    • self.act:激活函数 nn.SiLU(),一种平滑的激活函数,也称为 Swish 激活函数。

forward 前向方法定义了数据如何通过网络进行传递

  • 参数:

    • x:输入张量。
    • scale_shift:一个可选参数,包含 (scale, shift),用于缩放和平移 x
  • 前向过程:

    • x = self.proj(x)将输入 x 通过卷积层进行卷积操作
    • x = self.norm(x)将卷积后的结果进行分组归一化
    • 检查 scale_shift 是否存在(使用 exists(scale_shift))。如果存在,将 scale_shift 拆分为 scaleshift,其中 scale 是缩放因子,shift 是偏移量 。并应用以下变换:x = x * (scale + 1) + shift,将 x 进行缩放和偏移调整。
    • x = self.act(x):将结果通过激活函数 SiLU 激活。
    • 返回结果 x

2.2 残差连接

Residual :残差连接模块,接受一个函数作为输入, 返回该函数的输出与输入的和

python 复制代码
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

2.3 ResnetBlock

python 复制代码
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)

        return h + self.res_conv(x)

这段代码定义了一个 ResnetBlock 类,它是用于深度学习模型中的残差块,继承自 nn.ModuleResnetBlock 包括两个 Block 层和一个残差连接,并可以在需要时接受时间嵌入(time_emb)用于调节网络行为。

__init__ 方法 初始化组件:

  • self.mlp: 一个 MLP 层,用于将时间嵌入投影到 dim_out * 2 维度。如果 time_emb_dim 存在,则创建此 MLP;否则为 None
  • self.block1self.block2: 两个 Block 实例,分别用于特征提取
  • self.res_conv: 一个 1x1 卷积层(或恒等映射),用于调整输入 x 与输出 h 在维度上的匹配。如果 dimdim_out 不相同,则使用卷积层进行调整。

具体来说,

self.block1 = Block(dim, dim_out, groups = groups)

self.block2 = Block(dim_out, dim_out, groups = groups)

这两行代码实例化了两个 Block 层,并在 ResnetBlock 中分别赋值给 self.block1self.block2。每个 Block 对象的初始化会执行 Block 类的 __init__ 方法中的内容,

  • 当创建 self.block1 时,Block 的 __init__ 方法执行以下步骤:
    self.proj 初始化为一个 WeightStandardizedConv2d 卷积层,它的输入通道数为 dim,输出通道数为 dim_out,卷积核大小为 3,并且带有 1 个像素的填充
    self.norm 初始化为 GroupNorm,用于将输出通道 dim_out 进行分组归一化。
    self.act 初始化为 SiLU 激活函数。
  • 创建 self.block2 时,执行了 Block 的 __init__ 方法,但与 self.block1 的主要区别是:
    这次 proj 卷积层的输入通道和输出通道都是 dim_out ,使得输出通道数保持不变

forward 方法:

  • 接收输入 x 和可选的时间嵌入 time_emb
  • 如果 self.mlp 存在且 time_emb 存在,则将 time_emb 通过 MLP,并通过 rearrange 重塑为适合卷积操作的形状。然后将其切分为 scaleshift(用于 scale_shift 操作)。
  • 通过 block1 进行前向传播,并应用 scale_shift
  • 经过 block2 进一步处理。
  • 最后,返回 h + self.res_conv(x),实现残差连接

2.3 Attention

python 复制代码
class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        q = q * self.scale
        sim = einsum('b h d i, b h d j -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

这段代码定义了一个用于图像数据的自注意力机制 模块,Attention,它继承自 nn.Module 并用于深度学习模型中。

  1. __init__ 方法:

    • 接收 dim(输入通道数)、heads(多头注意力头的数量)、dim_head(每个注意力头的维度)作为参数。
    • 计算 hidden_dimdim_head * heads,这是多头注意力中每个 QKV 的总维度。
    • self.scale 用于缩放 Q 向量,以稳定训练。
    • self.to_qkv 是一个 1x1 卷积层,将输入的特征图变换为 Q(查询)、K(键)和 V(值)向量 ,输出通道数是 hidden_dim * 3
    • self.to_out 是另一个 1x1 卷积层,用于将注意力机制的输出映射回输入的维度 dim
  2. forward 方法:

    • 输入 x 是一个四维张量,形状为 (batch_size, channels, height, width)
    • 使用 self.to_qkv(x) 生成 QKV,并通过 chunk(3, dim=1) 将它们分开。
    • 使用 map()rearrange()QKV 重排,使它们适应多头注意力的形状 (batch_size, heads, dim_head, tokens),其中 tokens = height * width
    • Q 向量乘以 self.scale 进行缩放。
    • 计算相似度矩阵 sim,通过 einsum('b h d i, b h d j -> b h i j', q, k) 实现,表示 QK 之间的点积。
    • 使用 softmax 计算注意力权重 attn
    • 计算注意力加权后的输出 out,通过 einsum('b h i j, b h d j -> b h i d', attn, v) 将注意力矩阵应用于 V
    • 重排输出形状回 (batch_size, channels, height, width)
    • 最后,使用 self.to_out(out) 将输出映射回原输入通道数。

该模块用于提取图像特征的自注意力机制,帮助模型在处理复杂输入数据时,捕获长距离依赖和上下文信息。

相关推荐
思则变1 小时前
[Pytest] [Part 2]增加 log功能
开发语言·python·pytest
漫谈网络1 小时前
WebSocket 在前后端的完整使用流程
javascript·python·websocket
try2find3 小时前
安装llama-cpp-python踩坑记
开发语言·python·llama
博观而约取4 小时前
Django ORM 1. 创建模型(Model)
数据库·python·django
精灵vector5 小时前
构建专家级SQL Agent交互
python·aigc·ai编程
Zonda要好好学习6 小时前
Python入门Day2
开发语言·python
Vertira6 小时前
pdf 合并 python实现(已解决)
前端·python·pdf
太凉6 小时前
Python之 sorted() 函数的基本语法
python
项目題供诗6 小时前
黑马python(二十四)
开发语言·python
晓13137 小时前
OpenCV篇——项目(二)OCR文档扫描
人工智能·python·opencv·pycharm·ocr