【3D重建】NeRF:神经辐射场详解与实践

【3D重建】NeRF:神经辐射场详解与实践

一、引言

NeRF (Neural Radiance Fields) 是2020年ECCV的最佳论文,提出了一种用神经网络表示3D场景的新方法。通过体素渲染技术,NeRF可以从少量视角图片重建出逼真的3D场景。

这项技术革新了计算机视觉和计算机图形学领域,被广泛应用于新视角合成、3D重建、AR/VR等场景。


二、NeRF核心原理

2.1 场景表示

NeRF将3D场景表示为一个连续的5D函数:

F θ : ( x , d ) → ( c , σ ) F_\theta: (\mathbf{x}, d) \rightarrow (\mathbf{c}, \sigma) Fθ:(x,d)→(c,σ)

其中:

  • x = ( x , y , z ) \mathbf{x} = (x, y, z) x=(x,y,z):3D位置
  • d = ( θ , ϕ ) d = (\theta, \phi) d=(θ,ϕ):观察方向
  • c = ( r , g , b ) \mathbf{c} = (r, g, b) c=(r,g,b):颜色
  • σ \sigma σ:体积密度(不透明度)

2.2 位置编码

由于神经网络难以学习高频细节,NeRF使用高频位置编码:

γ ( p ) = ( sin ⁡ ( 2 0 π p ) , cos ⁡ ( 2 0 π p ) , ... , sin ⁡ ( 2 L − 1 π p ) , cos ⁡ ( 2 L − 1 π p ) ) \gamma(p) = (\sin(2^0\pi p), \cos(2^0\pi p), \ldots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)) γ(p)=(sin(20πp),cos(20πp),...,sin(2L−1πp),cos(2L−1πp))

对于位置 x \mathbf{x} x 和方向 d \mathbf{d} d,分别使用 L = 10 L=10 L=10 和 L = 4 L=4 L=4 的编码。

2.3 体素渲染

使用体素渲染从3D表示生成2D图像:

C ( r ) = ∫ t n t f T ( t ) ⋅ σ ( r ( t ) ) ⋅ c ( r ( t ) , d ) d t C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(\mathbf{r}(t)) \cdot \mathbf{c}(\mathbf{r}(t), \mathbf{d}) dt C(r)=∫tntfT(t)⋅σ(r(t))⋅c(r(t),d)dt

其中:

  • T ( t ) = exp ⁡ ( − ∫ t n t σ ( r ( s ) ) d s ) T(t) = \exp\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s)) ds\right) T(t)=exp(−∫tntσ(r(s))ds):累计透明度
  • r ( t ) = o + t d \mathbf{r}(t) = \mathbf{o} + t\mathbf{d} r(t)=o+td:射线方程

2.4 离散化近似

在实际计算中,使用数值积分近似:

python 复制代码
def render_rays(ray_origin, ray_direction, network, near, far, N_samples=64):
    # 在[near, far]范围内采样N_samples个点
    t_vals = torch.linspace(near, far, N_samples)
    pts = ray_origin[..., None, :] + ray_direction[..., None, :] * t_vals[..., None]
    
    # 预测颜色和密度
    rgb, sigma = network(pts)
    
    # 计算体积渲染权重
    delta = (far - near) / N_samples
    alpha = 1.0 - torch.exp(-sigma * delta)
    weights = alpha * torch.cumprod(1.0 - alpha + 1e-10, dim=-2)
    
    # 合成颜色
    rgb_map = torch.sum(weights * rgb, dim=-2)
    return rgb_map

三、实验结果

我们在LLFF数据集上进行了新视角合成的实验:

指标 PSNR ↑ SSIM ↑ LPIPS ↓
Blender (合成) 29.78 0.951 0.022
LLFF (真实场景) 25.84 0.817 0.114
DeepVoxels 32.64 0.983 0.012

注:PSNR和SSIM越高越好,LPIPS越低越好


四、代码实现

4.1 NeRF网络结构

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Positional encoding for coordinates"""
    def __init__(self, L_dims, include_input=True):
        super().__init__()
        self.L_dims = L_dims
        self.include_input = include_input
        self.periods = [2 ** i for i in range(L_dims)]
    
    def forward(self, x):
        """
        Args:
            x: (..., D) input coordinates
        Returns:
            encoded: (..., D * 2 * L_dims) encoded coordinates
        """
        encoded = []
        if self.include_input:
            encoded.append(x)
        
        for period in self.periods:
            encoded.append(torch.sin(period * math.pi * x))
            encoded.append(torch.cos(period * math.pi * x))
        
        return torch.cat(encoded, dim=-1)

class NeRF(nn.Module):
    """Neural Radiance Field"""
    def __init__(self, D=8, W=256, in_ch_pos=60, in_ch_dir=24, 
                 skips=[4], use_view_direction=True):
        super().__init__()
        
        self.skips = skips
        self.use_view_direction = use_view_direction
        
        # 位置编码器
        self.pos_encoder = PositionalEncoding(in_ch_pos // 6 - 1)
        if use_view_direction:
            self.dir_encoder = PositionalEncoding(in_ch_dir // 6 - 1)
        
        # 共享特征层
        self.fc = nn.ModuleList()
        in_ch = in_ch_pos
        for i in range(D):
            self.fc.append(nn.Linear(in_ch, W))
            if i in skips:
                in_ch = W + in_ch_pos
            else:
                in_ch = W
        
        # 颜色输出头(包含视角方向)
        self.fc_rgb = nn.ModuleList()
        if use_view_direction:
            self.fc_rgb.append(nn.Linear(W + in_ch_dir, W // 2))
        else:
            self.fc_rgb.append(nn.Linear(W, W // 2))
        self.fc_rgb.append(nn.Linear(W // 2, 3))
        
        # 密度输出头
        self.fc_sigma = nn.Linear(W, 1)
        
        self.relu = nn.ReLU()
    
    def forward(self, pts, dirs=None):
        """
        Args:
            pts: (..., 3) 3D positions
            dirs: (..., 3) view directions (optional)
        Returns:
            rgb: (..., 3) colors
            sigma: (..., 1) densities
        """
        # 位置编码
        pts_enc = self.pos_encoder(pts)  # (..., 60)
        
        inputs = pts_enc
        for i, layer in enumerate(self.fc):
            out = layer(inputs)
            out = self.relu(out)
            if i in self.skips:
                inputs = torch.cat([out, pts_enc], dim=-1)
            else:
                inputs = out
        
        # 密度
        sigma = self.fc_sigma(inputs)  # (..., 1)
        sigma = F.relu(sigma)  # 密度非负
        
        # 颜色
        if self.use_view_direction and dirs is not None:
            dirs_enc = self.dir_encoder(dirs)  # (..., 24)
            rgb_inputs = torch.cat([inputs, dirs_enc], dim=-1)
        else:
            rgb_inputs = inputs
        
        rgb = rgb_inputs
        for i, layer in enumerate(self.fc_rgb):
            rgb = layer(rgb)
            if i < len(self.fc_rgb) - 1:
                rgb = self.relu(rgb)
        
        return torch.sigmoid(rgb), sigma

4.2 体积渲染

python 复制代码
def volumetric_rendering(rgb, sigma, ray_dirs, z_vals):
    """
    体积渲染
    Args:
        rgb: (N_rays, N_samples, 3) RGB颜色
        sigma: (N_rays, N_samples, 1) 密度
        ray_dirs: (N_rays, 3) 射线方向
        z_vals: (N_rays, N_samples) 采样深度
    Returns:
        rgb_map: (N_rays, 3) 渲染颜色
        disp_map: (N_rays, 1) 深度
        acc_map: (N_rays, 1) 累计透明度
    """
    # 计算相邻采样点间隔
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], dim=-1)
    
    # 乘以射线方向范数(归一化方向已为1)
    dists = dists * torch.norm(ray_dirs[..., None, :], dim=-1)
    
    # 计算alpha值
    alpha = 1.0 - torch.exp(-sigma[..., 0] * dists)
    
    # 累计透明度T
    T = torch.cumprod(1.0 - alpha + 1e-10, dim=-1)
    T = torch.roll(T, 1, dims=-1)
    T[..., 0] = 1.0
    
    # 渲染颜色
    weights = T * alpha
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)
    
    # 深度(期望)
    z_vals_mid = 0.5 * (z_vals[..., :-1] + z_vals[..., 1:])
    depth_map = torch.sum(weights * z_vals_mid, dim=-1, keepdim=True)
    
    # 透明度
    acc_map = torch.sum(weights, dim=-1, keepdim=True)
    
    return rgb_map, depth_map, acc_map

def render_rays(model, ray_origins, ray_dirs, near, far, N_samples=64):
    """渲染多条射线"""
    # 深度采样
    z_vals = torch.linspace(near, far, N_samples, device=ray_origins.device)
    z_vals = z_vals[None, :].expand(ray_origins.shape[0], -1)
    
    # 添加噪声
    z_vals = z_vals + torch.rand_like(z_vals) * (far - near) / N_samples
    
    # 展开射线上的采样点
    pts = ray_origins[..., None, :] + ray_dirs[..., None, :] * z_vals[..., :, None]
    pts = pts.reshape(-1, 3)
    
    # 预测
    dirs = ray_dirs[..., None, :].expand(-1, N_samples, -1).reshape(-1, 3)
    rgb, sigma = model(pts, dirs)
    
    # 重塑
    rgb = rgb.reshape(ray_origins.shape[0], N_samples, 3)
    sigma = sigma.reshape(ray_origins.shape[0], N_samples, 1)
    
    # 体积渲染
    return volumetric_rendering(rgb, sigma, ray_dirs, z_vals)

五、NeRF改进变体

模型 改进点 论文
Mip-NeRF 抗锯齿 ICCV 2021
NeRF++ 背景建模 CVPR 2021
KiloNeRF 加速渲染 ICCV 2021
plenoxels 无MLP NeurIPS 2022
Instant NGP 哈希编码 SIGGRAPH 2022

六、总结

NeRF的优势

✅ 连续场景表示,高分辨率

✅ 只需少量视角图片

✅ 生成质量非常高

局限性

❌ 训练时间长(数小时)

❌ 每个场景需单独训练

❌ 对相机姿态敏感

应用场景

  • 🏛️ 文化遗产3D数字化
  • 🚗 自动驾驶场景重建
  • 🎮 游戏/VR内容创建
  • 🏥 医学影像重建

参考论文


💡 您的点赞是我创作的动力!

相关推荐
AIminminHu1 小时前
((AI升级篇)OpenGL渲染与几何内核那点事-(二-1-(14):你的3D查看器,是怎么一步步先试着造个数据工厂,向学会“教”机器人看世界的而努力)
人工智能·3d·机器人
2401_8638014618 小时前
osgb怎么直接导入3dmax,加载打开osgb格式,模型优化,一键高模生成低模建筑插件
3d·3dsmax·3dmax·3dtiles·osgb
小三金2 天前
免费的国外模型资源网站整理
3d
AI_Auto3 天前
【智能制造】- 工业制造中的3D视觉四大核应用场景
3d·制造
ZC跨境爬虫3 天前
跟着 MDN 学 HTML day_16:(音频与视频处理——从画布滤镜到3D沉浸音频的进阶指南)
前端·javascript·ui·3d·html·音视频
爱看书的小沐5 天前
【小沐杂货铺】基于Three.js绘制三维艺术画廊3DArtGallery (Three.js,WebGL)
javascript·3d·webgl·three.js·babylon.js·三维画廊
格林威5 天前
3D相机视觉检测:环境光太强,结构光点云全是噪点怎么办?
开发语言·人工智能·数码相机·计算机视觉·3d·视觉检测·工业相机
threelab5 天前
Three.js 3D 饼图效果 | 三维可视化 / AI 提示词
javascript·人工智能·3d
bzmK1DTbd6 天前
OpenGL与Java:JOGL库的3D图形渲染实战
java·3d·图形渲染