从0开始用 PyTorch 构建完整的 NeRF

未分类1个月前发布 tree
14 0 0
↑ 点击蓝字 关注极市平台
从0开始用 PyTorch 构建完整的 NeRF
作者丨匡吉
来源丨深蓝AI
编辑丨极市平台

极市导读

 

笔者通过整理分析了NeRF论文和相关参考代码,将为读者朋友讲述利用PyTorch框架,从0到1简单复现一个NeRF(神经辐射场)的实现细节和过程。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

在解释代码之前,首先对NeRF(神经辐射场)的原理与含义进行简单回顾。而NeRF论文中是这样解释NeRF算法流程的:

“我们提出了一个当前最优的方法,应用于复杂场景下合成新视图的任务,具体的实现原理是使用一个稀疏的输入视图集合,然后不断优化底层的连续体素场景函数。我们的算法,使用一个全连接(非卷积)的深度网络,表示一个场景,这个深度网络的输入是一个单独的5D坐标(空间位置(x,y,z)和视图方向(xita,sigma)),其对应的输出则是体素密度和视图关联的辐射向量。我们通过查询沿着相机射线的5D坐标合成新的场景视图,以及通过使用经典的体素渲染技术将输出颜色和密度投射到图像中。因为体素渲染具有天然的可变性,所以优化我们的表示方法所需的唯一输入就是一组已知相机位姿的图像。我们介绍如何高效优化神经辐射场照度,以渲染具有复杂几何形状和外观的逼真新颖视图,并展示了由于之前神经渲染和视图合成工作的结果。”

从0开始用 PyTorch 构建完整的 NeRF
▲图1|NeRF实现流程©️【深蓝AI】

基于前文的原理,本节开始讲述具体的代码实现。首先,导入算法需要的Python库文件。

import os  
from typing import Optional,Tuple,List,Union,Callable  
  
import numpy as np  
import torch  
from torch import nn  
import matplotlib.pyplot as plt  
from mpl_toolkits.mplot3d import axes3d  
from tqdm import trange  
  
# 设置GPU还是CPU设备  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1 输入

根据相关论文中的介绍可知,NeRF的输入是一个包含空间位置坐标与视图方向的5D坐标。然而,在PyTorch构建NeRF过程中使用的数据集只是一般的3D到2D图像数据集,包含拍摄相机的内参:位姿和焦距。因此在后面的操作中,我们会把输入数据集转为算法模型需要的输入形式。

在这一流程中使用乐高推土机图像作为简单NeRF算法的数据集,如图2所示:(具体的数据链接请在文末查看)

从0开始用 PyTorch 构建完整的 NeRF
▲图2|乐高推土机数据集©️【深蓝AI】

这项工作中使用的小型乐高数据集由 106 幅乐高推土机的图像组成,并配有位姿数据和常用焦距数值。与其他数据集一样,这里保留前 100 张图像用于训练,并保留一张测试图像用于验证,具体的加载数据操作如下:

data = np.load('tiny_nerf_data.npz'# 加载数据集  
images = data['images']  # 图像数据  
poses = data['poses']  # 位姿数据  
focal = data['focal']  # 焦距数值  
  
  
print(f'Images shape: {images.shape}')  
print(f'Poses shape: {poses.shape}')  
print(f'Focal length: {focal}')  
  
  
  
height, width = images.shape[1:3]  
near, far = 2., 6.  
  
  
  
n_training = 100 # 训练数据数量  
testimg_idx = 101 # 测试数据下标  
testimg, testpose = images[testimg_idx], poses[testimg_idx]  
  
  
plt.imshow(testimg)  
print('Pose')  
print(testpose)

2 数据处理

回顾NeRF相关论文, 本次代码实现需要的输入是一个单独的5D坐标 (空间位置 和视图方向 , sigma ))。因此, 我们需要针对上面使用的小型乐高数据做一个处理操作。

一般而言,为了收集这些特点输入数据,算法中需要对输入图像进行反渲染操作。具体来讲就是通过每个像素点在三维空间中绘制投影线,并从中提取样本。

要从图像以外的三维空间采样输入数据点,首先就得从乐高照片集中获取每台相机的初始位姿,然后通过一些矢量数学运算,将这些4×4姿态矩阵转换成「表示原点的三维坐标和表示方向的三维矢量」——这两类信息最终会结合起来描述一个矢量,该矢量用以表征拍摄照片时相机的指向。

下列代码则正是通过绘制箭头来描述这一操作,箭头表示每一帧图像的原点和方向:

# 方向数据  
dirs = np.stack([np.sum([0, 0, -1] * pose[:3, :3], axis=-1) for pose in poses])  
# 原点数据  
origins = poses[:, :3, -1]  
  
  
  
# 绘图的设置  
  
ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')  
_ = ax.quiver(  
  origins[..., 0].flatten(),  
  origins[..., 1].flatten(),  
  origins[..., 2].flatten(),  
  dirs[..., 0].flatten(),  
  dirs[..., 1].flatten(),  
  dirs[..., 2].flatten(), length=0.5, normalize=True)  
  
ax.set_xlabel('X')  
  
ax.set_ylabel('Y')  
  
ax.set_zlabel('z')  
  
plt.show()

最终绘制出来的箭头结果如下图所示:

从0开始用 PyTorch 构建完整的 NeRF
▲图3|采样点相机拍摄指向©️【深蓝AI】

当有了这些相机位姿数据之后,我们就可以沿着图像的每个像素找到投影线,而每条投影线都是由其原点(x,y,z)和方向联合定义。其中每个像素的原点可能相同,但方向一般是不同的。这些方向射线都略微偏离中心,因此不会存在两条平行方向线,如下图所示:

从0开始用 PyTorch 构建完整的 NeRF
▲图4|相机内参示意图©️【深蓝AI】编译

根据图4所述的原理,我们就可以确定每条射线的方向和原点,相关代码如下:

def get_rays(  
  height: int, # 图像高度  
  width: int, # 图像宽带  
  focal_length: float# 焦距  
  c2w: torch.Tensor  
) -> Tuple[torch.Tensor, torch.Tensor]:  
  """  
  通过每个像素和相机原点,找到射线的原点和方向。  
   
 """
  
  
  
  
  # 应用针孔相机模型收集每个像素的方向  
  i, j = torch.meshgrid(  
      torch.arange(width, dtype=torch.float32).to(c2w),  
      torch.arange(height, dtype=torch.float32).to(c2w),  
      indexing='ij')  
        
  i, j = i.transpose(-1, -2), j.transpose(-1, -2)  
  
  # 方向数据  
  directions = torch.stack([(i - width * .5) / focal_length,  
                            -(j - height * .5) / focal_length,  
                            -torch.ones_like(i)  
                           ], dim=-1)  
  
  
  
  # 用相机位姿求出方向  
  rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)  
  
  
  
  # 默认所有射线原点相同  
  rays_o = c2w[:3, -1].expand(rays_d.shape)  
  return rays_o, rays_d

得到每个像素对应的射线的方向数据和原点数据之后,就能够获得了NeRF算法中需要的五维数据输入,下面将这些数据调整为算法输入的格式:

# 转为PyTorch的tensor  
images = torch.from_numpy(data['images'][:n_training]).to(device)  
poses = torch.from_numpy(data['poses']).to(device)  
focal = torch.from_numpy(data['focal']).to(device)  
testimg = torch.from_numpy(data['images'][testimg_idx]).to(device)  
testpose = torch.from_numpy(data['poses'][testimg_idx]).to(device)  
  
  
# 针对每个图像获取射线  
height, width = images.shape[1:3]  
  
with torch.no_grad():  
  ray_origin, ray_direction = get_rays(height, width, focal, testpose)  
  
  
print('Ray Origin')  
  
print(ray_origin.shape)  
print(ray_origin[height // 2, width // 2, :])  
print('')  
  
  
print('Ray Direction')  
  
print(ray_direction.shape)  
print(ray_direction[height // 2, width // 2, :])  
print('')

2.1 分层采样

当算法输入模块有了NeRF算法需要的输入数据,也就是包含原点和方向向量组合的线条时,就可以在线条上进行采样。这一过程是采用从粗到细的采样策略,即分层采样策略。

具体来说,分层采样就是将光线分成均匀分布的小块,接着在每个小块内随机抽样。其中扰动的设置决定了是均匀取样的,还是直接简单使用分区中心作为采样点。具体操作代码如下所示:

# 采样函数定义  
def sample_stratified(  
  rays_o: torch.Tensor, # 射线原点  
  rays_d: torch.Tensor, # 射线方向  
  near: float,  
  far: float,  
  n_samples: int, # 采样数量  
  perturb: Optional[bool] = True, # 扰动设置  
  inverse_depth: bool = False  # 反向深度  
) -> Tuple[torch.Tensor, torch.Tensor]:  
  """  
  从规则的bin中沿着射线进行采样。  
  """
  
  
  # 沿着射线抓取采样点  
  t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)  
  if not inverse_depth:  
    # 由远到近线性采样  
    z_vals = near * (1.-t_vals) + far * (t_vals)  
  else:  
    # 在反向深度中线性采样  
  z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))  
  
  # 沿着射线从bins中统一采样  
  if perturb:  
    mids = .5 * (z_vals[1:] + z_vals[:-1])  
    upper = torch.concat([mids, z_vals[-1:]], dim=-1)  
    lower = torch.concat([z_vals[:1], mids], dim=-1)  
    t_rand = torch.rand([n_samples], device=z_vals.device)  
    z_vals = lower + (upper - lower) * t_rand  
  z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])  
  
  # 应用相应的缩放参数  
  pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]  
  return pts, z_vals

接着就到了对这些采样点做可视化分析的步骤。如图5中所述,未受扰动的蓝 色点是bin的“中心“,而红点对应扰动点的采样。请注意,红点与上方的蓝点略有偏移,但所有点都在远近采样设定值之间。具体代码如下:

y_vals = torch.zeros_like(z_vals)  
# 调用采样策略函数  
_, z_vals_unperturbed = sample_stratified(rays_o, rays_d, near, far, n_samples,  
                                  perturb=False, inverse_depth=inverse_depth)  
# 绘图相关

plt.plot(z_vals_unperturbed[0].cpu().numpy(), 1 + y_vals[0].cpu().numpy(), 'b-o')  
plt.plot(z_vals[0].cpu().numpy(), y_vals[0].cpu().numpy(), 'r-o')  
plt.ylim([-1, 2])  
plt.title('Stratified Sampling (blue) with Perturbation (red)')  
ax = plt.gca()  
ax.axes.yaxis.set_visible(False)  
plt.grid(True)
从0开始用 PyTorch 构建完整的 NeRF
▲图5|采样结果示意图©️【深蓝AI】

3 位置编码

与Transformer一样,NeRF也使用了位置编码器。因此NeRF就需要借助位置编码器将输入映射到更高的频率空间,以弥补神经网络在学习低频函数时的偏差。

这一环节将会为位置编码器建立一个简单的 torch.nn.Module 模块,相同的编码器可同时用于对输入样本和视图方向的编码操作。注意,这些输入被指定了不同的参数。代码如下所示:

# 位置编码类  
class PositionalEncoder(nn.Module):  
  """  
  对输入点,做sine或者consine位置编码。  
  """
  
  def __init__(  
    self,  
    d_input: int,  
    n_freqs: int,  
    log_space: bool = False  
  ):  
    super().__init__()  
    self.d_input = d_input  
    self.n_freqs = n_freqs  
    self.log_space = log_space  
    self.d_output = d_input * (1 + 2 * self.n_freqs)  
    self.embed_fns = [lambda x: x]  
  
    # 定义线性或者log尺度的频率  
    if self.log_space:  
      freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)  
    else:  
      freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)  
  
    # 替换sin和cos  
    for freq in freq_bands:  
      self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))  
      self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))  
    
  def forward(  
    self,  
    x  
  ) -> torch.Tensor:  
    """  
    实际使用位置编码的函数。  
    """
  
    return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)  

4 NeRF模型

在此,定义一个NeRF 模型——主要由线性层模块列表构成,而列表中进一步包含非线性激活函数和残差连接。该模型有一个可选的视图方向输入,如果在实例化时提供具体的方向信息,那么会改变模型结构。

(本实现基于原始论文NeRF:Representing Scenes as Neural Radiance Fields for View Synthesis 的第3节,并使用相同的默认设置)

具体代码如下所示:

# 定义NeRF模型  
class NeRF(nn.Module):  
  """  
  神经辐射场模块。  
  """
  
  def __init__(  
    self,  
    d_input: int = 3,  
    n_layers: int = 8,  
    d_filter: int = 256,  
    skip: Tuple[int] = (4,),  
    d_viewdirs: Optional[int] = None  
  ):  
    super().__init__()  
    self.d_input = d_input # 输入  
    self.skip = skip # 残差连接  
    self.act = nn.functional.relu # 激活函数  
    self.d_viewdirs = d_viewdirs # 视图方向  
  
    # 创建模型的层结构  
    self.layers = nn.ModuleList(  
      [nn.Linear(self.d_input, d_filter)] +  
      [nn.Linear(d_filter + self.d_input, d_filter) if i in skip \  
       else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]  
    )  
  
    # Bottleneck 层  
    if self.d_viewdirs is not None:  
      # 如果使用视图方向,分离alpha和RGB  
      self.alpha_out = nn.Linear(d_filter, 1)  
      self.rgb_filters = nn.Linear(d_filter, d_filter)  
      self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)  
      self.output = nn.Linear(d_filter // 2, 3)  
    else:  
      # 如果不使用试图方向,则简单输出  
      self.output = nn.Linear(d_filter, 4)  
    
  def forward(  
    self,  
    x: torch.Tensor,  
    viewdirs: Optional[torch.Tensor] = None  
  ) -> torch.Tensor:  
    r"""  
    带有视图方向的前向传播  
    """
  
  
    # 判断是否设置视图方向  
    if self.d_viewdirs is None and viewdirs is not None:  
      raise ValueError('Cannot input x_direction if d_viewdirs was not given.')  
  
    # 运行bottleneck层之前的网络层  
    x_input = x  
    for i, layer in enumerate(self.layers):  
      x = self.act(layer(x))  
      if i in self.skip:  
        x = torch.cat([x, x_input], dim=-1)  
  
    # 运行 bottleneck  
    if self.d_viewdirs is not None:  
      # Split alpha from network output  
      alpha = self.alpha_out(x)  
  
      # 结果传入到rgb过滤器  
      x = self.rgb_filters(x)  
      x = torch.concat([x, viewdirs], dim=-1)  
      x = self.act(self.branch(x))  
      x = self.output(x)  
  
      # 拼接alpha一起作为输出  
      x = torch.concat([x, alpha], dim=-1)  
    else:  
      # 不拼接,简单输出  
      x = self.output(x)  
    return x

5 体积渲染

上面得到NeRF模型的输出结果之后,仍需将NeRF的输出转换成图像。也就是通过渲染模块对每个像素沿光线方向的所有样本进行加权求和,从而得到该像素的估计颜色值,此外每个RGB样本都会根据其Alpha值进行加权。其中Alpha值越高,表明采样区域不透明的可能性越大,因此沿射线方向越远的点越有可能被遮挡,累加乘积可确保更远处的点受到抑制。具体代码如下:

# 体积渲染  
def cumprod_exclusive(  
  tensor: torch.Tensor  
) -> torch.Tensor:  
  """  
  (Courtesy of https://github.com/krrish94/nerf-pytorch)  
  
  和tf.math.cumprod(..., exclusive=True)功能类似  
  
  参数:  
  tensor (torch.Tensor): Tensor whose cumprod (cumulative product, see `torch.cumprod`) along dim=-1  
    is to be computed.  
  返回值:  
  cumprod (torch.Tensor): cumprod of Tensor along dim=-1, mimiciking the functionality of  
    tf.math.cumprod(..., exclusive=True) (see `tf.math.cumprod` for details).  
  """
  
  
  # 首先计算规则的cunprod  
  cumprod = torch.cumprod(tensor, -1)  
  cumprod = torch.roll(cumprod, 1, -1)  
  # 用1替换首个元素  
  cumprod[..., 0] = 1.  
    
  return cumprod  
  
# 输出到图像的函数  
def raw2outputs(  
  raw: torch.Tensor,  
  z_vals: torch.Tensor,  
  rays_d: torch.Tensor,  
  raw_noise_std: float = 0.0,  
  white_bkgd: bool = False  
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:  
  """  
  将NeRF的输出转换为RGB输出。  
  """
  
  
  # 沿着`z_vals`轴元素之间的差值.  
  dists = z_vals[..., 1:] - z_vals[..., :-1]  
  dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)  
  
  # 将每个距离乘以相应方向射线的法线,转换为现实世界中的距离(考虑非单位方向)。  
  dists = dists * torch.norm(rays_d[..., None, :], dim=-1)  
  
  # 为模型预测密度添加噪音。可用于在训练过程中对网络进行正则化(防止出现浮点伪影)。  
  noise = 0.  
  if raw_noise_std > 0.:  
    noise = torch.randn(raw[..., 3].shape) * raw_noise_std  
  
  # Predict density of each sample along each ray. Higher values imply  
  # higher likelihood of being absorbed at this point. [n_rays, n_samples]  
  alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)  
  
  # 预测每条射线上每个样本的密度。数值越大,表示该点被吸收的可能性越大。[n_ 射线,n_样本]  
  weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)  
  
  # 计算RGB图的权重。  
  rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]  
  rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [n_rays, 3]  
  
  # 估计预测距离的深度图。  
  depth_map = torch.sum(weights * z_vals, dim=-1)  
  
  # 稀疏图  
  disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),  
                            depth_map / torch.sum(weights, -1))  
  
  # 沿着每条射线加权。  
  acc_map = torch.sum(weights, dim=-1)  
  
  # 要合成到白色背景上,请使用累积的 alpha 贴图。  
  if white_bkgd:  
    rgb_map = rgb_map + (1. - acc_map[..., None])  
  
  return rgb_map, depth_map, acc_map, weights

6 分层体积采样

事实上,三维空间中的遮挡物非常稀疏,因此大多数点对渲染图像的贡献不大。所以,对积分有贡献的区域进行超采样会有更好的效果。这里,笔者对第一组样本应用基于归一化的权重来创建整个光线的概率密度函数,然后对该密度函数应用反变换采样来收集第二组样本。具体代码如下:

# 采样概率密度函数  
def sample_pdf(  
  bins: torch.Tensor,  
  weights: torch.Tensor,  
  n_samples: int,  
  perturb: bool = False  
) -> torch.Tensor:  
  """  
  应用反向转换采样到一组加权点。  
  """
  
  
  # 正则化权重得到概率密度函数。  
  pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdims=True) # [n_rays, weights.shape[-1]]  
  
  # 将概率密度函数转为累计分布函数。  
  cdf = torch.cumsum(pdf, dim=-1) # [n_rays, weights.shape[-1]]  
  cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) # [n_rays, weights.shape[-1] + 1]  
  
  # 从累计分布函数中提取样本位置。perturb == 0 时为线性。  
  if not perturb:  
    u = torch.linspace(0., 1., n_samples, device=cdf.device)  
    u = u.expand(list(cdf.shape[:-1]) + [n_samples]) # [n_rays, n_samples]  
  else:  
    u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device) # [n_rays, n_samples]  
  
  # 沿累计分布函数找出 u 值所在的索引。  
  u = u.contiguous() # 返回具有相同值的连续张量。  
  inds = torch.searchsorted(cdf, u, right=True) # [n_rays, n_samples]  
  
  # 夹住超出范围的索引。  
  below = torch.clamp(inds - 1, min=0)  
  above = torch.clamp(inds, max=cdf.shape[-1] - 1)  
  inds_g = torch.stack([below, above], dim=-1) # [n_rays, n_samples, 2]  
  
  # 从累计分布函数和相应的 bin 中心取样。  
  matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]  
  cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,  
                       index=inds_g)  
  bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1,  
                        index=inds_g)  
  
  # 将样本转换为射线长度。  
  denom = (cdf_g[..., 1] - cdf_g[..., 0])  
  denom = torch.where(denom   t = (u - cdf_g[..., 0]) / denom  
  samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])  
  
  return samples # [n_rays, n_samples]

7 整体的前向传播流程

此时应将上面所有内容整合在一起,通过模型计算一次前向传递。

由于潜在的内存问题,前向传递以“块“为单位进行计算,然后汇总到一个批次中。梯度传播是在整个批次处理完毕后进行的,因此有“块“和“批次“之分。对于内存紧张环境来说,分块处理尤为重要,因为该环境下提供的资源比原始论文中引用的资源更为有限。具体代码如下所示:

def get_chunks(  
  inputs: torch.Tensor,  
  chunksize: int = 2**15  
) -> List[torch.Tensor]:  
  """  
  输入分块。  
  """
  
  return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]  
  
def prepare_chunks(  
  points: torch.Tensor,  
  encoding_function: Callable[[torch.Tensor], torch.Tensor],  
  chunksize: int = 2**15  
) -> List[torch.Tensor]:  
  """  
  对点进行编码和分块,为 NeRF 模型做好准备。  
  """
  
  points = points.reshape((-1, 3))  
  points = encoding_function(points)  
  points = get_chunks(points, chunksize=chunksize)  
  return points  
  
def prepare_viewdirs_chunks(  
  points: torch.Tensor,  
  rays_d: torch.Tensor,  
  encoding_function: Callable[[torch.Tensor], torch.Tensor],  
  chunksize: int = 2**15  
) -> List[torch.Tensor]:  
  r"""  
  对视图方向进行编码和分块,为 NeRF 模型做好准备。  
  """
  
  viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)  
  viewdirs = viewdirs[:, None, ...].expand(points.shape).reshape((-1, 3))  
  viewdirs = encoding_function(viewdirs)  
  viewdirs = get_chunks(viewdirs, chunksize=chunksize)  
  return viewdirs  
def nerf_forward(  
  rays_o: torch.Tensor,  
  rays_d: torch.Tensor,  
  near: float,  
  far: float,  
  encoding_fn: Callable[[torch.Tensor], torch.Tensor],  
  coarse_model: nn.Module,  
  kwargs_sample_stratified: dict = None,  
  n_samples_hierarchical: int = 0,  
  kwargs_sample_hierarchical: dict = None,  
  fine_model = None,  
  viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,  
  chunksize: int = 2**15  
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:  
  """  
  计算一次前向传播  
  """
  
  
  # 设置参数  
  if kwargs_sample_stratified is None:  
    kwargs_sample_stratified = {}  
  if kwargs_sample_hierarchical is None:  
    kwargs_sample_hierarchical = {}  
    
  # 沿着每条射线的样本查询点。  
  query_points, z_vals = sample_stratified(  
      rays_o, rays_d, near, far, **kwargs_sample_stratified)  
  
  # 准备批次。  
  batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)  
  if viewdirs_encoding_fn is not None:  
    batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d,  
                                               viewdirs_encoding_fn,  
                                               chunksize=chunksize)  
  else:  
    batches_viewdirs = [None] * len(batches)  
  
  # 稀疏模型流程。  
  predictions = []  
  for batch, batch_viewdirs in zip(batches, batches_viewdirs):  
    predictions.append(coarse_model(batch, viewdirs=batch_viewdirs))  
  raw = torch.cat(predictions, dim=0)  
  raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])  
  
  # 执行可微分体积渲染,重新合成 RGB 图像。  
  rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)  
    
  outputs = {  
      'z_vals_stratified': z_vals  
  }  
  
  if n_samples_hierarchical > 0:  
    # Save previous outputs to return.  
    rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map  
  
    # 对精细查询点进行分层抽样。  
    query_points, z_vals_combined, z_hierarch = sample_hierarchical(  
      rays_o, rays_d, z_vals, weights, n_samples_hierarchical,  
      **kwargs_sample_hierarchical)  
  
    # 像以前一样准备输入。  
    batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)  
    if viewdirs_encoding_fn is not None:  
      batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d,  
                                                 viewdirs_encoding_fn,  
                                                 chunksize=chunksize)  
    else:  
      batches_viewdirs = [None] * len(batches)  
  
    # 通过精细模型向前传递新样本。  
    fine_model = fine_model if fine_model is not None else coarse_model  
    predictions = []  
    for batch, batch_viewdirs in zip(batches, batches_viewdirs):  
      predictions.append(fine_model(batch, viewdirs=batch_viewdirs))  
    raw = torch.cat(predictions, dim=0)  
    raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])  
  
    # 执行可微分体积渲染,重新合成 RGB 图像。  
    rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals_combined, rays_d)  
      
    # 存储输出  
    outputs['z_vals_hierarchical'] = z_hierarch  
    outputs['rgb_map_0'] = rgb_map_0  
    outputs['depth_map_0'] = depth_map_0  
    outputs['acc_map_0'] = acc_map_0  
  
  # 存储输出  
  outputs['rgb_map'] = rgb_map  
  outputs['depth_map'] = depth_map  
  outputs['acc_map'] = acc_map  
  outputs['weights'] = weights  
  return outputs

到这一步骤,就几乎拥有了训练模型所需的一切模块。现在为一个简单的训练过程做一些设置,创建超参数和辅助函数,然后来训练模型。

7.1 超参数

所有用于训练的超参数都在此设置,默认值取自原始论文中数据,除非计算上有限制。在计算受限情况下,本次讨论采用的都是合理的默认值。

# 编码器  
d_input = 3           # 输入维度  
n_freqs = 10          # 输入到编码函数中的样本点数量  
log_space = True      # 如果设置,频率按对数空间缩放  
use_viewdirs = True   # 如果设置,则使用视图方向作为输入  
n_freqs_views = 4     # 视图编码功能的数量  
  
# 采样策略  
n_samples = 64         # 每条射线的空间样本数  
perturb = True         # 如果设置,则对采样位置应用噪声  
inverse_depth = False  # 如果设置,则按反深度线性采样点  
  
# 模型  
d_filter = 128          # 线性层滤波器的尺寸  
n_layers = 2            # bottleneck层数量  
skip = []               # 应用输入残差的层级  
use_fine_model = True   # 如果设置,则创建一个精细模型  
d_filter_fine = 128     # 精细网络线性层滤波器的尺寸  
n_layers_fine = 6       # 精细网络瓶颈层数  
  
# 分层采样  
n_samples_hierarchical = 64   # 每条射线的样本数  
perturb_hierarchical = False  # 如果设置,则对采样位置应用噪声  
  
# 优化器  
lr = 5e-4  # 学习率  
  
# 训练  
n_iters = 10000  
batch_size = 2**14          # 每个梯度步长的射线数量(2 的幂次)  
one_image_per_step = True   # 每个梯度步骤一个图像(禁用批处理)  
chunksize = 2**14           # 根据需要进行修改,以适应 GPU 内存  
center_crop = True          # 裁剪图像的中心部分(每幅图像裁剪一次)  
center_crop_iters = 50      # 经过这么多epoch后,停止裁剪中心  
display_rate = 25          # 每 X 个epoch显示一次测试输出  
  
# 早停  
warmup_iters = 100          # 热身阶段的迭代次数  
warmup_min_fitness = 10.0   # 在热身_iters 处继续训练的最小 PSNR 值  
n_restarts = 10             # 训练停滞时重新开始的次数  
  
# 捆绑了各种函数的参数,以便一次性传递。  
kwargs_sample_stratified = {  
    'n_samples': n_samples,  
    'perturb': perturb,  
    'inverse_depth': inverse_depth  
}  
kwargs_sample_hierarchical = {  
    'perturb': perturb  
}

7.2 训练类和函数

这一环节会创建一些用于训练的辅助函数。NeRF很容易出现局部最小值,在这种情况下,训练很快就会停滞并产生空白输出。必要时,会利用EarlyStopping重新启动训练。

# 绘制采样函数  
def plot_samples(  
  z_vals: torch.Tensor,  
  z_hierarch: Optional[torch.Tensor] = None,  
  ax: Optional[np.ndarray] = None):  
  r"""  
  绘制分层样本和(可选)分级样本。  
  """
  
  y_vals = 1 + np.zeros_like(z_vals)  
  
  if ax is None:  
    ax = plt.subplot()  
  ax.plot(z_vals, y_vals, 'b-o')  
  if z_hierarch is not None:  
    y_hierarch = np.zeros_like(z_hierarch)  
    ax.plot(z_hierarch, y_hierarch, 'r-o')  
  ax.set_ylim([-1, 2])  
  ax.set_title('Stratified  Samples (blue) and Hierarchical Samples (red)')  
  ax.axes.yaxis.set_visible(False)  
  ax.grid(True)  
  return ax  
  
def crop_center(  
  img: torch.Tensor,  
  frac: float = 0.5  
) -> torch.Tensor:  
  r"""  
  从图像中裁剪中心方形。  
  """
  
  h_offset = round(img.shape[0] * (frac / 2))  
  w_offset = round(img.shape[1] * (frac / 2))  
  return img[h_offset:-h_offset, w_offset:-w_offset]  
  
class EarlyStopping:  
  r"""  
  基于适配标准的早期停止辅助器  
  """
  
  def __init__(  
    self,  
    patience: int = 30,  
    margin: float = 1e-4  
  ):  
    self.best_fitness = 0.0  
    self.best_iter = 0  
    self.margin = margin  
    self.patience = patience or float('inf')  # 在epoch停止提高后等待的停止时间  
  
  def __call__(  
    self,  
    iter: int,  
    fitness: float  
  ):  
    r"""  
    检查是否符合停止标准。  
    """
  
    if (fitness - self.best_fitness) > self.margin:  
      self.best_iter = iter  
      self.best_fitness = fitness  
    delta = iter - self.best_iter  
    stop = delta >= self.patience  # 超过耐性则停止训练  
    return stop  
def init_models():  
  r"""  
  为 NeRF 训练初始化模型、编码器和优化器。  
  """
  
  # 编码器  
  encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)  
  encode = lambda x: encoder(x)  
  
  # 视图方向编码  
  if use_viewdirs:  
    encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,  
                                        log_space=log_space)  
    encode_viewdirs = lambda x: encoder_viewdirs(x)  
    d_viewdirs = encoder_viewdirs.d_output  
  else:  
    encode_viewdirs = None  
    d_viewdirs = None  
  
  # 模型  
  model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,  
              d_viewdirs=d_viewdirs)  
  model.to(device)  
  model_params = list(model.parameters())  
  if use_fine_model:  
    fine_model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,  
                      d_viewdirs=d_viewdirs)  
    fine_model.to(device)  
    model_params = model_params + list(fine_model.parameters())  
  else:  
    fine_model = None  
  
  # 优化器  
  optimizer = torch.optim.Adam(model_params, lr=lr)  
  
  # 早停  
  warmup_stopper = EarlyStopping(patience=50)  
  
  return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper

7.3 训练循环

下面就是具体的训练循环过程函数:

def train():  
  r"""  
  启动 NeRF 训练。  
  """
  
  # 对所有图像进行射线洗牌。  
  if not one_image_per_step:  
    height, width = images.shape[1:3]  
    all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0)  
                        for p in poses[:n_training]], 0)  
    rays_rgb = torch.cat([all_rays, images[:, None]], 1)  
    rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])  
    rays_rgb = rays_rgb.reshape([-1, 3, 3])  
    rays_rgb = rays_rgb.type(torch.float32)  
    rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]  
    i_batch = 0  
  
  train_psnrs = []  
  val_psnrs = []  
  iternums = []  
  for i in trange(n_iters):  
    model.train()  
  
    if one_image_per_step:  
      # 随机选择一张图片作为目标。  
      target_img_idx = np.random.randint(images.shape[0])  
      target_img = images[target_img_idx].to(device)  
      if center_crop and i         target_img = crop_center(target_img)  
      height, width = target_img.shape[:2]  
      target_pose = poses[target_img_idx].to(device)  
      rays_o, rays_d = get_rays(height, width, focal, target_pose)  
      rays_o = rays_o.reshape([-1, 3])  
      rays_d = rays_d.reshape([-1, 3])  
    else:  
      # 在所有图像上随机显示。  
      batch = rays_rgb[i_batch:i_batch + batch_size]  
      batch = torch.transpose(batch, 0, 1)  
      rays_o, rays_d, target_img = batch  
      height, width = target_img.shape[:2]  
      i_batch += batch_size  
      # 一个epoch后洗牌  
      if i_batch >= rays_rgb.shape[0]:  
          rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]  
          i_batch = 0  
    target_img = target_img.reshape([-1, 3])  
  
    # 运行 TinyNeRF 的一次迭代,得到渲染后的 RGB 图像。  
    outputs = nerf_forward(rays_o, rays_d,  
                           near, far, encode, model,  
                           kwargs_sample_stratified=kwargs_sample_stratified,  
                           n_samples_hierarchical=n_samples_hierarchical,  
                           kwargs_sample_hierarchical=kwargs_sample_hierarchical,  
                           fine_model=fine_model,  
                           viewdirs_encoding_fn=encode_viewdirs,  
                           chunksize=chunksize)  
      
    # 检查任何数字问题。  
    for k, v in outputs.items():  
      if torch.isnan(v).any():  
        print(f"! [Numerical Alert] {k} contains NaN.")  
      if torch.isinf(v).any():  
        print(f"! [Numerical Alert] {k} contains Inf.")  
  
    # 反向传播  
    rgb_predicted = outputs['rgb_map']  
    loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)  
    loss.backward()  
    optimizer.step()  
    optimizer.zero_grad()  
    psnr = -10. * torch.log10(loss)  
    train_psnrs.append(psnr.item())  
  
    # 以给定的显示速率评估测试值。  
    if i % display_rate == 0:  
      model.eval()  
      height, width = testimg.shape[:2]  
      rays_o, rays_d = get_rays(height, width, focal, testpose)  
      rays_o = rays_o.reshape([-1, 3])  
      rays_d = rays_d.reshape([-1, 3])  
      outputs = nerf_forward(rays_o, rays_d,  
                             near, far, encode, model,  
                             kwargs_sample_stratified=kwargs_sample_stratified,  
                             n_samples_hierarchical=n_samples_hierarchical,  
                             kwargs_sample_hierarchical=kwargs_sample_hierarchical,  
                             fine_model=fine_model,  
                             viewdirs_encoding_fn=encode_viewdirs,  
                             chunksize=chunksize)  
  
      rgb_predicted = outputs['rgb_map']  
      loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3))  
      print("Loss:", loss.item())  
      val_psnr = -10. * torch.log10(loss)  
      val_psnrs.append(val_psnr.item())  
      iternums.append(i)  
  
      # 绘制输出示例  
      fig, ax = plt.subplots(1, 4, figsize=(24,4), gridspec_kw={'width_ratios': [1, 1, 1, 3]})  
      ax[0].imshow(rgb_predicted.reshape([height, width, 3]).detach().cpu().numpy())  
      ax[0].set_title(f'Iteration: {i}')  
      ax[1].imshow(testimg.detach().cpu().numpy())  
      ax[1].set_title(f'Target')  
      ax[2].plot(range(0, i + 1), train_psnrs, 'r')  
      ax[2].plot(iternums, val_psnrs, 'b')  
      ax[2].set_title('PSNR (train=red, val=blue')  
      z_vals_strat = outputs['z_vals_stratified'].view((-1, n_samples))  
      z_sample_strat = z_vals_strat[z_vals_strat.shape[0] // 2].detach().cpu().numpy()  
      if 'z_vals_hierarchical' in outputs:  
        z_vals_hierarch = outputs['z_vals_hierarchical'].view((-1, n_samples_hierarchical))  
        z_sample_hierarch = z_vals_hierarch[z_vals_hierarch.shape[0] // 2].detach().cpu().numpy()  
      else:  
        z_sample_hierarch = None  
      _ = plot_samples(z_sample_strat, z_sample_hierarch, ax=ax[3])  
      ax[3].margins(0)  
      plt.show()  
  
    # 检查 PSNR 是否存在问题,如果发现问题,则停止运行。  
    if i == warmup_iters - 1:  
      if val_psnr         print(f'Val PSNR {val_psnr} below warmup_min_fitness {warmup_min_fitness}. Stopping...')  
        return False, train_psnrs, val_psnrs  
    elif i       if warmup_stopper is not None and warmup_stopper(i, psnr):  
        print(f'Train PSNR flatlined at {psnr} for {warmup_stopper.patience} iters. Stopping...')  
        return False, train_psnrs, val_psnrs  
      
  return True, train_psnrs, val_psnrs

最终的结果如下图所示:

从0开始用 PyTorch 构建完整的 NeRF
从0开始用 PyTorch 构建完整的 NeRF
▲图6|运行结果示意图©️【深蓝AI】

引用:

[1]https://www.matthewtancik.com/nerf

[2]http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

[3]https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

[4]https://medium.com/@rparikshat1998/nerf-from-scratch-fe21c08b145d

从0开始用 PyTorch 构建完整的 NeRF

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

从0开始用 PyTorch 构建完整的 NeRF

点击阅读原文进入CV社区

收获更多技术干货


© 版权声明

相关文章

暂无评论

暂无评论...