1 模型训练
python
复制代码
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
first_iter = 0
# 初始化高斯模型,用于表示场景中的每个点的3D高斯分布
gaussians = GaussianModel(dataset.sh_degree)
# 初始化场景对象,加载数据集和对应的相机参数
scene = Scene(dataset, gaussians)
# 为高斯模型参数设置优化器和学习率调度器
gaussians.training_setup(opt)
# 如果提供了checkpoint,则从checkpoint加载模型参数并恢复训练进度
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint)
gaussians.restore(model_params, opt)
# 设置背景颜色,白色或黑色取决于数据集要求
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
# 创建CUDA事件用于计时
iter_start = torch.cuda.Event(enable_timing=True)
iter_end = torch.cuda.Event(enable_timing=True)
viewpoint_stack = None
ema_loss_for_log = 0.0
# 使用tqdm库创建进度条,追踪训练进度
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):
# 记录迭代开始时间
iter_start.record()
# 根据当前迭代次数更新学习率
gaussians.update_learning_rate(iteration)
# 每1000次迭代,提升球谐函数的次数以改进模型复杂度
if iteration % 1000 == 0:
gaussians.oneupSHdegree()
# 随机选择一个训练用的相机视角
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
# 如果达到调试起始点,启用调试模式
if (iteration - 1) == debug_from:
pipe.debug = True
# 根据设置决定是否使用随机背景颜色
bg = torch.rand((3), device="cuda") if opt.random_background else background
# 渲染当前视角的图像
render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
# 计算渲染图像与真实图像之间的损失
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
loss.backward()
# 记录迭代结束时间
iter_end.record()
with torch.no_grad():
# 更新进度条和损失显示
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
if iteration % 10 == 0:
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()
# 定期记录训练数据并保存模型
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
if iteration in saving_iterations:
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)
# 在指定迭代区间内,对3D高斯模型进行增密和修剪
if iteration < opt.densify_until_iter:
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()
# 执行优化器的一步,并准备下一次迭代
if iteration < opt.iterations:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none=True)
# 定期保存checkpoint
if iteration in checkpoint_iterations:
print("\n[ITER {}] Saving Checkpoint".format(iteration))
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
2 数据加载
python
复制代码
class Scene:
"""
Scene 类用于管理场景的3D模型,包括相机参数、点云数据和高斯模型的初始化和加载
"""
def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
"""
初始化场景对象
:param args: 包含模型路径和源路径等模型参数
:param gaussians: 高斯模型对象,用于场景点的3D表示
:param load_iteration: 指定加载模型的迭代次数,如果为-1,则自动寻找最大迭代次数
:param shuffle: 是否在训练前打乱相机列表
:param resolution_scales: 分辨率比例列表,用于处理不同分辨率的相机
"""
self.model_path = args.model_path # 模型文件保存路径
self.loaded_iter = None # 已加载的迭代次数
self.gaussians = gaussians # 高斯模型对象
# 检查并加载已有的训练模型
if load_iteration:
if load_iteration == -1:
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
else:
self.loaded_iter = load_iteration
print(f"Loading trained model at iteration {self.loaded_iter}")
self.train_cameras = {} # 用于训练的相机参数
self.test_cameras = {} # 用于测试的相机参数
# 根据数据集类型(COLMAP或Blender)加载场景信息
if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
else:
assert False, "Could not recognize scene type!"
# 如果是初次训练,初始化3D高斯模型;否则,加载已有模型
if self.loaded_iter:
self.gaussians.load_ply(os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply"))
else:
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
# 根据resolution_scales加载不同分辨率的训练和测试位姿
for resolution_scale in resolution_scales:
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
def save(self, iteration):
"""
保存当前迭代下的3D高斯模型点云。
:param iteration: 当前的迭代次数。
"""
point_cloud_path = os.path.join(self.model_path, f"point_cloud/iteration_{iteration}")
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
def getTrainCameras(self, scale=1.0):
"""
获取指定分辨率比例的训练相机列表
:param scale: 分辨率比例
:return: 指定分辨率比例的训练相机列表
"""
return self.train_cameras[scale]
sceneLoadTypeCallbacks = {
"Colmap": readColmapSceneInfo,
"Blender" : readNerfSyntheticInfo
}
python
复制代码
def readColmapSceneInfo(path, images, eval, llffhold=8):
# 尝试读取COLMAP处理结果中的二进制相机外参和内参文件
try:
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
except:
# 如果二进制文件读取失败,尝试读取文本格式的相机外参和内参文件
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
# 定义存放图片的目录,如果未指定则默认为"images"
reading_dir = "images" if images is None else images
# 读取并处理相机参数,转换为内部使用的格式
cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
# 根据图片名称对相机信息进行排序,以保证顺序一致性
cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
# 根据是否为评估模式(eval),将相机分为训练集和测试集
# 如果为评估模式,根据llffhold参数(通常用于LLFF数据集)间隔选择测试相机
if eval:
train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
else:
# 如果不是评估模式,所有相机均为训练相机,测试相机列表为空
train_cam_infos = cam_infos
test_cam_infos = []
# 计算场景归一化参数,这是为了处理不同尺寸和位置的场景,使模型训练更稳定
nerf_normalization = getNerfppNorm(train_cam_infos)
# 尝试读取点云数据,优先从PLY文件读取,如果不存在,则尝试从BIN或TXT文件转换并保存为PLY格式
ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt")
if not os.path.exists(ply_path):
print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
try:
xyz, rgb, _ = read_points3D_binary(bin_path)
except:
xyz, rgb, _ = read_points3D_text(txt_path)
storePly(ply_path, xyz, rgb)
try:
pcd = fetchPly(ply_path)
except:
pcd = None
# 组装场景信息,包括点云、训练用相机、测试用相机、场景归一化参数和点云文件路径
scene_info = SceneInfo(point_cloud=pcd,
train_cameras=train_cam_infos,
test_cameras=test_cam_infos,
nerf_normalization=nerf_normalization,
ply_path=ply_path)
return scene_info
def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
cam_infos = [] # 初始化用于存储相机信息的列表
# 遍历所有相机的外参
for idx, key in enumerate(cam_extrinsics):
# 动态显示读取相机信息的进度
sys.stdout.write('\r')
sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
sys.stdout.flush()
# 获取当前相机的外参和内参
extr = cam_extrinsics[key] # 当前相机的外参
intr = cam_intrinsics[extr.camera_id] # 根据外参中的camera_id找到对应的内参
height = intr.height # 相机图片的高度
width = intr.width # 相机图片的宽度
uid = intr.id # 相机的唯一标识符
# 将四元数表示的旋转转换为旋转矩阵R
R = np.transpose(qvec2rotmat(extr.qvec))
# 外参中的平移向量
T = np.array(extr.tvec)
# 根据相机内参模型计算视场角(FoV)
if intr.model == "SIMPLE_PINHOLE":
# 如果是简单针孔模型,只有一个焦距参数
focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height) # 计算垂直方向的视场角
FovX = focal2fov(focal_length_x, width) # 计算水平方向的视场角
elif intr.model == "PINHOLE":
# 如果是针孔模型,有两个焦距参数
focal_length_x = intr.params[0]
focal_length_y = intr.params[1]
FovY = focal2fov(focal_length_y, height) # 使用y方向的焦距计算垂直视场角
FovX = focal2fov(focal_length_x, width) # 使用x方向的焦距计算水平视场角
else:
# 如果不是以上两种模型,抛出错误
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
# 构建图片的完整路径
image_path = os.path.join(images_folder, os.path.basename(extr.name))
image_name = os.path.basename(image_path).split(".")[0] # 提取图片名称,不包含扩展名
# 使用PIL库打开图片文件
image = Image.open(image_path)
# 创建并存储相机信息
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
image_path=image_path, image_name=image_name, width=width, height=height)
cam_infos.append(cam_info)
# 在读取完所有相机信息后换行
sys.stdout.write('\n')
# 返回整理好的相机信息列表
return cam_infos