nerf_plus_plus/nerf_network.py

143 lines
4.9 KiB
Python
Raw Normal View History

2020-10-12 03:33:31 +02:00
import torch
import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
from collections import OrderedDict
import logging
logger = logging.getLogger(__package__)
2020-10-12 04:11:56 +02:00
2020-10-12 03:33:31 +02:00
class Embedder(nn.Module):
    '''
    Frequency embedding of an input tensor.

    Every input channel is multiplied by each frequency band and passed
    through each periodic function; the results (optionally preceded by the
    raw input) are concatenated along the last dimension.
    '''

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        '''
        :param input_dim: dimension of input to be embedded
        :param max_freq_log2: log2 of max freq; min freq is 1 by default
        :param N_freqs: number of frequency bands
        :param log_sampling: if True, frequency bands are linearly sampled in log-space
        :param include_input: if True, raw input is included in the embedding
        :param periodic_fns: periodic functions used to embed input
        '''
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # one block of input_dim channels per (frequency, periodic fn) pair,
        # plus the raw input itself when requested
        self.out_dim = self.input_dim * N_freqs * len(self.periodic_fns)
        if self.include_input:
            self.out_dim += self.input_dim

        if log_sampling:
            freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        # kept as a list of Python floats, so the bands are neither parameters
        # nor buffers of the module
        self.freq_bands = freq_bands.tolist()

    def forward(self, input):
        '''
        :param input: tensor of shape [..., self.input_dim]
        :return: tensor of shape [..., self.out_dim]
        :raises ValueError: if the last dimension of input is not self.input_dim
        '''
        # explicit raise instead of assert: asserts are stripped under `python -O`
        if input.shape[-1] != self.input_dim:
            raise ValueError(
                'expected last dimension {}, got {}'.format(
                    self.input_dim, input.shape[-1]))

        out = [input] if self.include_input else []
        # ordering: frequency-major, then periodic function
        out.extend(p_fn(input * freq)
                   for freq in self.freq_bands
                   for p_fn in self.periodic_fns)
        out = torch.cat(out, dim=-1)

        # internal consistency check between __init__ bookkeeping and forward
        assert out.shape[-1] == self.out_dim
        return out
# default tensorflow initialization of linear layers
def weights_init(m):
    '''
    Re-initialize a linear layer in place: Xavier-uniform weight, zero bias.

    Intended to be passed to ``nn.Module.apply``; any module that is not an
    ``nn.Linear`` is left untouched.

    :param m: module to (possibly) re-initialize
    '''
    if not isinstance(m, nn.Linear):
        return

    # nn.init functions already run without autograd tracking, so the
    # parameters are passed directly rather than through `.data`
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
class MLPNet(nn.Module):
    '''
    MLP that maps encoded positions and view directions to colour and density.

    A stack of ``D`` Linear+ReLU layers processes the position encoding, with
    the position encoding re-concatenated after each layer listed in
    ``skips``. A linear head on the resulting feature produces sigma; a second
    branch remaps the feature, appends the view-direction encoding and
    produces rgb.
    '''

    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3,
                 skips=(4,), use_viewdirs=False):
        '''
        :param D: network depth
        :param W: network width
        :param input_ch: input channels for encodings of (x, y, z)
        :param input_ch_viewdirs: input channels for encodings of view directions
        :param skips: skip connection in network
        :param use_viewdirs: if True, will use the view directions as input
        '''
        super().__init__()

        self.input_ch = input_ch
        self.input_ch_viewdirs = input_ch_viewdirs
        # NOTE(review): stored only; nothing in this class branches on it,
        # the view-direction branch is always built and evaluated
        self.use_viewdirs = use_viewdirs
        # private copy, so instances never share (or mutate) the caller's list
        self.skips = list(skips)

        base_layers = []
        dim = self.input_ch
        for i in range(D):
            base_layers.append(
                nn.Sequential(nn.Linear(dim, W), nn.ReLU())
            )
            dim = W
            # skip connection after i^th layer: the next layer also receives
            # the raw position encoding (never applied after the last layer)
            if i in self.skips and i != (D - 1):
                dim += input_ch
        self.base_layers = nn.ModuleList(base_layers)
        # xavier init (weights_init) is deliberately not applied to any of
        # the layers below; PyTorch's default initialization is used

        # density head; forward() takes the absolute value of its output
        self.sigma_layers = nn.Sequential(nn.Linear(dim, 1))

        # rgb color
        # NOTE: the remapped feature width is fixed at 256, independent of W
        self.base_remap_layers = nn.Sequential(nn.Linear(dim, 256))

        dim = 256 + self.input_ch_viewdirs
        self.rgb_layers = nn.Sequential(
            nn.Linear(dim, W),
            nn.ReLU(),
            nn.Linear(W, 3),
            nn.Sigmoid(),  # rgb values are normalized to [0, 1]
        )

    def forward(self, input):
        '''
        :param input: [..., input_ch+input_ch_viewdirs]
        :return: OrderedDict with
                 'rgb':   [..., 3], output of a sigmoid
                 'sigma': [...], absolute value of the density head output
        '''
        input_pts = input[..., :self.input_ch]

        last = len(self.base_layers) - 1
        base = input_pts
        for i, layer in enumerate(self.base_layers):
            base = layer(base)
            # mirrors the width bookkeeping done in __init__
            if i in self.skips and i != last:
                base = torch.cat((input_pts, base), dim=-1)

        sigma = torch.abs(self.sigma_layers(base))

        base_remap = self.base_remap_layers(base)
        # explicit start index: a plain `-n:` slice would return every
        # channel when input_ch_viewdirs == 0
        viewdirs_start = input.shape[-1] - self.input_ch_viewdirs
        input_viewdirs = input[..., viewdirs_start:]
        rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))

        ret = OrderedDict([('rgb', rgb),
                           ('sigma', sigma.squeeze(-1))])
        return ret