3D Gaussian Splatting代码详解(一):模型训练、数据加载

1 模型训练

python 复制代码
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
    first_iter = 0
    # 初始化高斯模型,用于表示场景中的每个点的3D高斯分布
    gaussians = GaussianModel(dataset.sh_degree)
    # 初始化场景对象,加载数据集和对应的相机参数
    scene = Scene(dataset, gaussians)
    # 为高斯模型参数设置优化器和学习率调度器
    gaussians.training_setup(opt)
    # 如果提供了checkpoint,则从checkpoint加载模型参数并恢复训练进度
    if checkpoint:
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)
    # 设置背景颜色,白色或黑色取决于数据集要求
    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # 创建CUDA事件用于计时
    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    # 使用tqdm库创建进度条,追踪训练进度
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1
    for iteration in range(first_iter, opt.iterations + 1):
        # 记录迭代开始时间
        iter_start.record()

        # 根据当前迭代次数更新学习率
        gaussians.update_learning_rate(iteration)

        # 每1000次迭代,提升球谐函数的次数以改进模型复杂度
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # 随机选择一个训练用的相机视角
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

        # 如果达到调试起始点,启用调试模式
        if (iteration - 1) == debug_from:
            pipe.debug = True

        # 根据设置决定是否使用随机背景颜色
        bg = torch.rand((3), device="cuda") if opt.random_background else background

        # 渲染当前视角的图像
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

        # 计算渲染图像与真实图像之间的损失
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()

        # 记录迭代结束时间
        iter_end.record()

        with torch.no_grad():
            # 更新进度条和损失显示
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # 定期记录训练数据并保存模型
            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # 在指定迭代区间内,对3D高斯模型进行增密和修剪
            if iteration < opt.densify_until_iter:
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                    gaussians.reset_opacity()

            # 执行优化器的一步,并准备下一次迭代
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            # 定期保存checkpoint
            if iteration in checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

2 数据加载

python 复制代码
class Scene:
    """
    Scene 类用于管理场景的3D模型,包括相机参数、点云数据和高斯模型的初始化和加载
    """

    def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
        """
        初始化场景对象
        
        :param args: 包含模型路径和源路径等模型参数
        :param gaussians: 高斯模型对象,用于场景点的3D表示
        :param load_iteration: 指定加载模型的迭代次数,如果为-1,则自动寻找最大迭代次数
        :param shuffle: 是否在训练前打乱相机列表
        :param resolution_scales: 分辨率比例列表,用于处理不同分辨率的相机
        """
        self.model_path = args.model_path  # 模型文件保存路径
        self.loaded_iter = None  # 已加载的迭代次数
        self.gaussians = gaussians  # 高斯模型对象

        # 检查并加载已有的训练模型
        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print(f"Loading trained model at iteration {self.loaded_iter}")

        self.train_cameras = {}  # 用于训练的相机参数
        self.test_cameras = {}  # 用于测试的相机参数

        # 根据数据集类型(COLMAP或Blender)加载场景信息
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        else:
            assert False, "Could not recognize scene type!"

        # 如果是初次训练,初始化3D高斯模型;否则,加载已有模型
        if self.loaded_iter:
            self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)

        # 根据resolution_scales加载不同分辨率的训练和测试位姿
        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)

    def save(self, iteration):
        """
        保存当前迭代下的3D高斯模型点云。
        
        :param iteration: 当前的迭代次数。
        """
        point_cloud_path = os.path.join(self.model_path, f"point_cloud/iteration_{iteration}")
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

    def getTrainCameras(self, scale=1.0):
        """
        获取指定分辨率比例的训练相机列表
        
        :param scale: 分辨率比例
        :return: 指定分辨率比例的训练相机列表
        """
        return self.train_cameras[scale]
    
sceneLoadTypeCallbacks = {
    "Colmap": readColmapSceneInfo,
    "Blender" : readNerfSyntheticInfo
}
python 复制代码
def readColmapSceneInfo(path, images, eval, llffhold=8):
    # 尝试读取COLMAP处理结果中的二进制相机外参和内参文件
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except:
        # 如果二进制文件读取失败,尝试读取文本格式的相机外参和内参文件
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    # 定义存放图片的目录,如果未指定则默认为"images"
    reading_dir = "images" if images is None else images
    # 读取并处理相机参数,转换为内部使用的格式
    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
    # 根据图片名称对相机信息进行排序,以保证顺序一致性
    cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)

    # 根据是否为评估模式(eval),将相机分为训练集和测试集
    # 如果为评估模式,根据llffhold参数(通常用于LLFF数据集)间隔选择测试相机
    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        # 如果不是评估模式,所有相机均为训练相机,测试相机列表为空
        train_cam_infos = cam_infos
        test_cam_infos = []

    # 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定
    nerf_normalization = getNerfppNorm(train_cam_infos)

    # 尝试读取点云数据,优先从PLY文件读取,如果不存在,则尝试从BIN或TXT文件转换并保存为PLY格式
    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

    # 组装场景信息,包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径
    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    cam_infos = []  # 初始化用于存储相机信息的列表

    # 遍历所有相机的外参
    for idx, key in enumerate(cam_extrinsics):
        # 动态显示读取相机信息的进度
        sys.stdout.write('\r')
        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
        sys.stdout.flush()

        # 获取当前相机的外参和内参
        extr = cam_extrinsics[key]  # 当前相机的外参
        intr = cam_intrinsics[extr.camera_id]  # 根据外参中的camera_id找到对应的内参
        height = intr.height  # 相机图片的高度
        width = intr.width  # 相机图片的宽度

        uid = intr.id  # 相机的唯一标识符
        # 将四元数表示的旋转转换为旋转矩阵R
        R = np.transpose(qvec2rotmat(extr.qvec))
        # 外参中的平移向量
        T = np.array(extr.tvec)

        # 根据相机内参模型计算视场角(FoV)
        if intr.model == "SIMPLE_PINHOLE":
            # 如果是简单针孔模型,只有一个焦距参数
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)  # 计算垂直方向的视场角
            FovX = focal2fov(focal_length_x, width)  # 计算水平方向的视场角
        elif intr.model == "PINHOLE":
            # 如果是针孔模型,有两个焦距参数
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)  # 使用y方向的焦距计算垂直视场角
            FovX = focal2fov(focal_length_x, width)  # 使用x方向的焦距计算水平视场角
        else:
            # 如果不是以上两种模型,抛出错误
            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"

        # 构建图片的完整路径
        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        image_name = os.path.basename(image_path).split(".")[0]  # 提取图片名称,不包含扩展名
        # 使用PIL库打开图片文件
        image = Image.open(image_path)

        # 创建并存储相机信息
        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                              image_path=image_path, image_name=image_name, width=width, height=height)
        cam_infos.append(cam_info)
    
    # 在读取完所有相机信息后换行
    sys.stdout.write('\n')
    # 返回整理好的相机信息列表
    return cam_infos
相关推荐
ZHOU_WUYI3 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1233 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界4 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221514 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2514 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街4 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台5 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网
加密新世界5 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
hunteritself5 小时前
ChatGPT高级语音模式正在向Web网页端推出!
人工智能·gpt·chatgpt·openai·语音识别
Che_Che_6 小时前
Cross-Inlining Binary Function Similarity Detection
人工智能·网络安全·gnn·二进制相似度检测