nerf_plus_plus/ddp_model.py
2020-10-11 22:11:56 -04:00

148 lines
7.7 KiB
Python

import torch
import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
from utils import TINY_NUMBER, HUGE_NUMBER
from collections import OrderedDict
from nerf_network import Embedder, MLPNet
######################################################################################
# wrapper to simplify the use of nerfnet
######################################################################################
def depth2pts_outside(ray_o, ray_d, depth):
    '''
    Convert inverse-distance samples along rays into (x, y, z, 1/r) points
    for the region outside the unit sphere centred at the origin.

    :param ray_o, ray_d: [..., 3]; ray origins and directions
                         (ray_d does not have to be unit length)
    :param depth: [...]; inverse of distance to sphere origin, inside [0, 1]
    :return: pts: [..., 4]; unit vector (x, y, z) concatenated with depth
             depth_real: [...]; conventional depth, i.e. the ray parameter t
             (in units of ray_d) of the sampled point
    '''
    # Ray parameter of the point on the ray closest to the origin.
    # note: this becomes negative if that point is behind the camera
    dir_dot_origin = torch.sum(ray_d * ray_o, dim=-1)
    dir_sq_norm = torch.sum(ray_d * ray_d, dim=-1)
    t_mid = -dir_dot_origin / dir_sq_norm
    closest = ray_o + t_mid.unsqueeze(-1) * ray_d
    closest_norm = torch.norm(closest, dim=-1)

    # Converts a Euclidean length along the ray into units of ray_d.
    inv_dir_norm = 1. / torch.norm(ray_d, dim=-1)

    # Half chord of the unit sphere, then the point where the ray exits it.
    half_chord = torch.sqrt(1. - closest_norm * closest_norm) * inv_dir_norm
    on_sphere = ray_o + (t_mid + half_chord).unsqueeze(-1) * ray_d

    # Unit rotation axis, perpendicular to the plane through the origin,
    # the ray origin and the exit point.
    axis = torch.cross(ray_o, on_sphere, dim=-1)
    axis = axis / torch.norm(axis, dim=-1, keepdim=True)

    # phi: angle at the exit point; theta: angle at the sample with radius 1/depth
    phi = torch.asin(closest_norm)
    theta = torch.asin(closest_norm * depth)    # depth is inside [0, 1]
    angle = (phi - theta).unsqueeze(-1)         # [..., 1]

    # Rotate the exit point about the axis.
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    cos_term = on_sphere * torch.cos(angle)
    sin_term = torch.cross(axis, on_sphere, dim=-1) * torch.sin(angle)
    axis_term = axis * torch.sum(axis * on_sphere, dim=-1, keepdim=True) * (1. - torch.cos(angle))
    rotated = cos_term + sin_term + axis_term
    rotated = rotated / torch.norm(rotated, dim=-1, keepdim=True)

    pts = torch.cat((rotated, depth.unsqueeze(-1)), dim=-1)

    # Conventional depth: distance from the closest point to the sample is
    # (1 / depth) * cos(theta); rescale to units of ray_d and offset by t_mid.
    depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * inv_dir_norm + t_mid
    return pts, depth_real
class NerfNet(nn.Module):
    '''
    Two-network radiance field: a foreground network evaluated at (x, y, z)
    sample points and a background network evaluated at (x, y, z, 1/r)
    points produced by depth2pts_outside. The two renderings are composited
    in forward().
    '''

    def __init__(self, args):
        '''
        :param args: options object; the attributes read here are
                     max_freq_log2, max_freq_log2_viewdirs (number of
                     frequencies of the position / view-direction embedders),
                     netdepth, netwidth (passed to MLPNet as D and W) and
                     use_viewdirs (passed through to MLPNet)
        '''
        super().__init__()
        pos_freqs = args.max_freq_log2
        dir_freqs = args.max_freq_log2_viewdirs

        # NOTE: submodule names and construction order are kept as they are
        # so that state_dict keys and parameter ordering do not change.

        # foreground
        self.fg_embedder_position = Embedder(input_dim=3,
                                             max_freq_log2=pos_freqs - 1,
                                             N_freqs=pos_freqs)
        self.fg_embedder_viewdir = Embedder(input_dim=3,
                                            max_freq_log2=dir_freqs - 1,
                                            N_freqs=dir_freqs)
        self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                             input_ch=self.fg_embedder_position.out_dim,
                             input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
                             use_viewdirs=args.use_viewdirs)

        # background; bg_pt is (x, y, z, 1/r)
        self.bg_embedder_position = Embedder(input_dim=4,
                                             max_freq_log2=pos_freqs - 1,
                                             N_freqs=pos_freqs)
        self.bg_embedder_viewdir = Embedder(input_dim=3,
                                            max_freq_log2=dir_freqs - 1,
                                            N_freqs=dir_freqs)
        self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                             input_ch=self.bg_embedder_position.out_dim,
                             input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
                             use_viewdirs=args.use_viewdirs)

    @staticmethod
    def _per_sample(x, batch_shape, n_samples):
        '''Broadcast a per-ray [..., 3] tensor to [..., n_samples, 3] (a view, no copy).'''
        return x.unsqueeze(-2).expand(batch_shape + [n_samples, 3])

    def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
        '''
        Render a batch of rays: alpha-composite the foreground samples, then
        the background samples, and blend the background in with the
        transmittance left over after the foreground.

        :param ray_o, ray_d: [..., 3]
        :param fg_z_max: [...,]; depth at which the foreground segment of
                         each ray ends (closes the last sample interval)
        :param fg_z_vals: [..., N_samples]; foreground depths along ray_d
        :param bg_z_vals: [..., N_samples]; background samples, passed to
                          depth2pts_outside as inverse distance to the origin;
                          they are flipped here so they run 1 ---> 0
        :return: OrderedDict with
                 'rgb'        [..., 3]  composited colour (used for the loss)
                 'fg_weights' [..., N_samples]  (for importance sampling)
                 'bg_weights' [..., N_samples]  (for importance sampling; in
                              the flipped sample order)
                 'fg_rgb'     [..., 3]
                 'fg_depth'   [...,]
                 'bg_rgb'     [..., 3]  already scaled by bg_lambda
                 'bg_depth'   [...,]    weighted sum of the inverse-distance
                              samples, already scaled by bg_lambda
                 'bg_lambda'  [...,]    transmittance after the last
                              foreground sample
        '''
        ray_d_norm = torch.norm(ray_d, dim=-1, keepdim=True)    # [..., 1]
        viewdirs = ray_d / ray_d_norm                           # [..., 3]
        batch_shape = list(ray_d.shape[:-1])

        ######### render foreground
        n_fg = fg_z_vals.shape[-1]
        fg_ray_o = self._per_sample(ray_o, batch_shape, n_fg)
        fg_ray_d = self._per_sample(ray_d, batch_shape, n_fg)
        fg_viewdirs = self._per_sample(viewdirs, batch_shape, n_fg)
        fg_pts = fg_ray_o + fg_z_vals.unsqueeze(-1) * fg_ray_d
        fg_input = torch.cat((self.fg_embedder_position(fg_pts),
                              self.fg_embedder_viewdir(fg_viewdirs)), dim=-1)
        fg_raw = self.fg_net(fg_input)

        # alpha blending; the last interval runs from the last sample to fg_z_max
        fg_dists = fg_z_vals[..., 1:] - fg_z_vals[..., :-1]
        last_fg_dist = fg_z_max.unsqueeze(-1) - fg_z_vals[..., -1:]
        # z differences are in units of ray_d; scale by |ray_d| to get distances
        fg_dists = ray_d_norm * torch.cat((fg_dists, last_fg_dist), dim=-1)    # [..., N_samples]
        fg_alpha = 1. - torch.exp(-fg_raw['sigma'] * fg_dists)                 # [..., N_samples]
        fg_T = torch.cumprod(1. - fg_alpha + TINY_NUMBER, dim=-1)              # [..., N_samples]
        # transmittance remaining after the whole foreground segment
        bg_lambda = fg_T[..., -1]
        # shift right by one so each sample sees the transmittance before it
        fg_T = torch.cat((torch.ones_like(fg_T[..., 0:1]), fg_T[..., :-1]), dim=-1)    # [..., N_samples]
        fg_weights = fg_alpha * fg_T                                           # [..., N_samples]
        fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2)    # [..., 3]
        fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1)               # [...,]

        ######### render background
        n_bg = bg_z_vals.shape[-1]
        bg_ray_o = self._per_sample(ray_o, batch_shape, n_bg)
        bg_ray_d = self._per_sample(ray_d, batch_shape, n_bg)
        bg_viewdirs = self._per_sample(viewdirs, batch_shape, n_bg)
        bg_pts, _ = depth2pts_outside(bg_ray_o, bg_ray_d, bg_z_vals)    # [..., N_samples, 4]
        bg_input = torch.cat((self.bg_embedder_position(bg_pts),
                              self.bg_embedder_viewdir(bg_viewdirs)), dim=-1)
        # near_depth: physical far; far_depth: physical near
        # flip both the network input and the samples so they are ordered
        # from physically near to physically far
        bg_input = torch.flip(bg_input, dims=[-2,])
        bg_z_vals = torch.flip(bg_z_vals, dims=[-1,])    # 1--->0
        bg_dists = bg_z_vals[..., :-1] - bg_z_vals[..., 1:]
        # The last interval is treated as unbounded. The ones tensor is built
        # from bg_z_vals (not from the N_samples-1 wide differences) so that a
        # single background sample still yields one interval.
        bg_dists = torch.cat((bg_dists,
                              HUGE_NUMBER * torch.ones_like(bg_z_vals[..., 0:1])), dim=-1)    # [..., N_samples]
        bg_raw = self.bg_net(bg_input)
        bg_alpha = 1. - torch.exp(-bg_raw['sigma'] * bg_dists)    # [..., N_samples]
        # Eq. (3): T
        # maths show weights, and summation of weights along a ray, are always inside [0, 1]
        bg_T = torch.cumprod(1. - bg_alpha + TINY_NUMBER, dim=-1)[..., :-1]    # [..., N_samples-1]
        # Leading ones come from bg_alpha for the same single-sample reason.
        bg_T = torch.cat((torch.ones_like(bg_alpha[..., 0:1]), bg_T), dim=-1)  # [..., N_samples]
        bg_weights = bg_alpha * bg_T                                           # [..., N_samples]
        bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2)    # [..., 3]
        bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1)               # [...,]

        ######### composite foreground and background
        bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
        bg_depth_map = bg_lambda * bg_depth_map
        rgb_map = fg_rgb_map + bg_rgb_map

        ret = OrderedDict([('rgb', rgb_map),               # loss
                           ('fg_weights', fg_weights),     # importance sampling
                           ('bg_weights', bg_weights),     # importance sampling
                           ('fg_rgb', fg_rgb_map),         # below are for logging
                           ('fg_depth', fg_depth_map),
                           ('bg_rgb', bg_rgb_map),
                           ('bg_depth', bg_depth_map),
                           ('bg_lambda', bg_lambda)])
        return ret