142 lines
4.9 KiB
Python
142 lines
4.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
# import torch.nn.functional as F
|
|
# import numpy as np
|
|
from collections import OrderedDict
|
|
|
|
import logging
|
|
# Module-level logger. It is named after the enclosing package (`__package__`),
# not after this module; when `__package__` is None or empty (e.g. a top-level
# module or a script) `logging.getLogger` returns the root logger.
logger = logging.getLogger(__package__)
|
|
|
|
|
|
class Embedder(nn.Module):
    '''
    Embeds the last dimension of a tensor with periodic functions evaluated
    at a fixed set of frequencies.

    The output concatenates, along the last dimension, the raw input (when
    ``include_input`` is True) followed by ``p_fn(input * freq)`` for every
    frequency in ``self.freq_bands`` and every function in ``periodic_fns``.
    ``self.out_dim`` holds the size of that last dimension.
    '''

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        '''
        :param input_dim: dimension of input to be embedded
        :param max_freq_log2: log2 of max freq; min freq is 1 by default
        :param N_freqs: number of frequency bands
        :param log_sampling: if True, frequency bands are linearly sampled in log-space;
                             otherwise they are linearly sampled between 1 and 2**max_freq_log2
        :param include_input: if True, raw input is included in the embedding
        :param periodic_fns: periodic functions used to embed input
        '''
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # One block of `input_dim` channels per (frequency, function) pair,
        # plus one more block for the raw input when it is included.
        self.out_dim = self.input_dim * N_freqs * len(self.periodic_fns)
        if self.include_input:
            self.out_dim += self.input_dim

        if log_sampling:
            freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)

        # Stored as plain Python floats so the frequencies multiply inputs of
        # any dtype/device without being moved along with the module.
        self.freq_bands = freq_bands.tolist()

    def forward(self, input):
        '''
        :param input: tensor of shape [..., self.input_dim]
        :return: tensor of shape [..., self.out_dim]
        '''
        assert (input.shape[-1] == self.input_dim)

        out = [input] if self.include_input else []
        for freq in self.freq_bands:
            scaled = input * freq  # shared by every periodic function
            for p_fn in self.periodic_fns:
                out.append(p_fn(scaled))

        if not out:
            # Nothing to concatenate (no raw input, no frequency bands or no
            # periodic functions): torch.cat would raise on an empty list, so
            # return an empty embedding that matches out_dim == 0.
            return input.new_empty((*input.shape[:-1], 0))

        out = torch.cat(out, dim=-1)

        assert (out.shape[-1] == self.out_dim)
        return out
|
|
|
|
# default tensorflow initialization of linear layers
|
|
def weights_init(m):
    '''
    Re-initialise a linear layer in place: Xavier-uniform weights and, when
    the layer has a bias, an all-zero bias.

    Modules that are not ``nn.Linear`` are left untouched, so the function
    can be handed to ``nn.Module.apply``.

    :param m: module to (possibly) re-initialise
    '''
    if not isinstance(m, nn.Linear):
        return

    nn.init.xavier_uniform_(m.weight.data)

    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias.data)
|
|
|
|
|
|
class MLPNet(nn.Module):
    '''
    Fully connected network that maps an encoded point (and encoded view
    direction) to an RGB colour in [0, 1] and a non-negative ``sigma``.

    A trunk of ``D`` Linear+ReLU layers processes the point encoding, with
    the point encoding re-concatenated after every layer listed in
    ``skips``. ``sigma`` is read directly from the trunk output; the colour
    head consumes a 256-wide linear remap of the trunk output concatenated
    with the view-direction encoding.
    '''

    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3,
                 skips=(4,), use_viewdirs=False):
        '''
        :param D: network depth
        :param W: network width
        :param input_ch: input channels for encodings of (x, y, z)
        :param input_ch_viewdirs: input channels for encodings of view directions
        :param skips: skip connection in network (indices of the layers after
                      which the point encoding is concatenated again)
        :param use_viewdirs: if True, will use the view directions as input.
                             NOTE: the flag is only stored; this class always
                             feeds the view-direction channels to the rgb head.
        '''
        super().__init__()

        self.input_ch = input_ch
        self.input_ch_viewdirs = input_ch_viewdirs
        self.use_viewdirs = use_viewdirs
        # Private copy: the argument (or a shared default) must never be
        # mutated through one instance and leak into another.
        self.skips = list(skips)

        # All layers keep PyTorch's default initialisation; the xavier
        # initialisation (weights_init) is intentionally not applied here.
        base_layers = []
        dim = self.input_ch
        for i in range(D):
            base_layers.append(nn.Sequential(nn.Linear(dim, W), nn.ReLU()))
            dim = W
            if i in self.skips and i != (D - 1):  # skip connection after i^th layer
                dim += input_ch
        self.base_layers = nn.ModuleList(base_layers)

        # sigma must be positive; forward() enforces that with abs()
        self.sigma_layers = nn.Sequential(nn.Linear(dim, 1))

        # rgb color: the trunk output is remapped to a fixed 256 features
        # (independent of W) before the view directions are appended
        self.base_remap_layers = nn.Sequential(nn.Linear(dim, 256))
        self.rgb_layers = nn.Sequential(
            nn.Linear(256 + self.input_ch_viewdirs, W),
            nn.ReLU(),
            nn.Linear(W, 3),
            nn.Sigmoid(),  # rgb values are normalized to [0, 1]
        )

    def forward(self, input):
        '''
        :param input: [..., input_ch+input_ch_viewdirs]
        :return: OrderedDict with 'rgb' of shape [..., 3] and 'sigma' of shape [...]
        '''
        input_pts = input[..., :self.input_ch]

        # Mirrors the constructor: the point encoding is concatenated after
        # layer i when i is a skip index and layer i is not the last one.
        last = len(self.base_layers) - 1
        base = input_pts
        for i, layer in enumerate(self.base_layers):
            base = layer(base)
            if i in self.skips and i != last:
                base = torch.cat((input_pts, base), dim=-1)

        sigma = torch.abs(self.sigma_layers(base))

        base_remap = self.base_remap_layers(base)
        # Take the trailing input_ch_viewdirs channels. An explicit start
        # index is required because `-0:` would select every channel when
        # input_ch_viewdirs == 0.
        view_start = max(input.shape[-1] - self.input_ch_viewdirs, 0)
        input_viewdirs = input[..., view_start:]
        rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))

        ret = OrderedDict([('rgb', rgb),
                           ('sigma', sigma.squeeze(-1))])
        return ret
|