remove auxiliary loss

Author: Kai-46
Date:   2020-10-11 22:11:56 -04:00
Parent: e60d030d34
Commit: 2af283ff67

5 changed files with 25 additions and 293 deletions

.gitignore (vendored): 6 changed lines

@@ -1,3 +1,9 @@
+# scripts
+*.sh
+
+# pycharm
+.idea/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]


@@ -63,8 +63,7 @@ class NerfNet(nn.Module):
         self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.fg_embedder_position.out_dim,
                              input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)
         # background; bg_pt is (x, y, z, 1/r)
         self.bg_embedder_position = Embedder(input_dim=4,
                                              max_freq_log2=args.max_freq_log2 - 1,
@@ -75,8 +74,7 @@ class NerfNet(nn.Module):
         self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.bg_embedder_position.out_dim,
                              input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)

     def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
         '''
@@ -109,7 +107,6 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1)  # [..., N_samples]
         fg_weights = fg_alpha * T  # [..., N_samples]
         fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2)  # [..., 3]
-        fg_diffuse_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1)  # [...,]

         # render background
@@ -133,18 +130,14 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1)  # [..., N_samples]
         bg_weights = bg_alpha * T  # [..., N_samples]
         bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2)  # [..., 3]
-        bg_diffuse_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1)  # [...,]

         # composite foreground and background
         bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
-        bg_diffuse_rgb_map = bg_lambda.unsqueeze(-1) * bg_diffuse_rgb_map
         bg_depth_map = bg_lambda * bg_depth_map
         rgb_map = fg_rgb_map + bg_rgb_map
-        diffuse_rgb_map = fg_diffuse_rgb_map + bg_diffuse_rgb_map

         ret = OrderedDict([('rgb', rgb_map),            # loss
-                           ('diffuse_rgb', diffuse_rgb_map),  # regularize
                            ('fg_weights', fg_weights),  # importance sampling
                            ('bg_weights', bg_weights),  # importance sampling
                            ('fg_rgb', fg_rgb_map),      # below are for logging
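For context, the compositing kept by this commit is the usual foreground/background split: the foreground color is alpha-composited from per-sample weights, and the background color is attenuated by the leftover foreground transmittance (bg_lambda, computed earlier in forward and not shown in this hunk). A minimal toy sketch, not part of the commit:

import torch

# one ray, two foreground samples, two background samples
fg_weights = torch.tensor([[0.3, 0.4]])                        # [..., N_samples]
fg_rgb = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])    # [..., N_samples, 3]
bg_weights = torch.tensor([[0.5, 0.5]])
bg_rgb = torch.tensor([[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]])
bg_lambda = torch.tensor([0.3])                                 # leftover fg transmittance (toy value)

fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_rgb, dim=-2)   # [..., 3]
bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, dim=-2)   # [..., 3]
rgb_map = fg_rgb_map + bg_lambda.unsqueeze(-1) * bg_rgb_map
# rgb_map -> tensor([[0.3000, 0.4000, 0.3000]])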


@@ -1,22 +1,19 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing
 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time
 from data_loader_split import load_data_split
 import numpy as np
 from tensorboardX import SummaryWriter
 from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
 import logging

 logger = logging.getLogger(__package__)
@@ -427,11 +424,6 @@ def ddp_train_nerf(rank, args):
             loss = img2mse(ret['rgb'], rgb_gt)
             scalars_to_log['level_{}/loss'.format(m)] = loss.item()
             scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
-            # regularize sigma with photo-consistency
-            diffuse_loss = img2mse(ret['diffuse_rgb'], rgb_gt)
-            scalars_to_log['level_{}/diffuse_loss'.format(m)] = diffuse_loss.item()
-            scalars_to_log['level_{}/diffuse_psnr'.format(m)] = mse2psnr(diffuse_loss.item())
-            loss = (1. - args.regularize_weight) * loss + args.regularize_weight * diffuse_loss
             loss.backward()
             optim.step()
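After this change the objective is just the photometric MSE on the composited color. A before/after sketch with toy tensors; img2mse is assumed here to be plain mean squared error (as in utils.py), and regularize_weight stands in for the removed --regularize_weight flag (default 0.5):

import torch

def img2mse(x, y):                    # assumed definition, matching utils.py
    return torch.mean((x - y) ** 2)

rgb_pred = torch.rand(4, 3)           # composited color
diffuse_pred = torch.rand(4, 3)       # the old auxiliary 'diffuse_rgb' output
rgb_gt = torch.rand(4, 3)
regularize_weight = 0.5               # old default of --regularize_weight

old_loss = (1. - regularize_weight) * img2mse(rgb_pred, rgb_gt) \
           + regularize_weight * img2mse(diffuse_pred, rgb_gt)
new_loss = img2mse(rgb_pred, rgb_gt)  # what remains after this commit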
@@ -571,9 +563,6 @@ def config_parser():
                         help='apply the trick to avoid fitting to white background')
     # use implicit
-    parser.add_argument("--use_implicit", action='store_true', help='whether to use implicit regularization')
-    parser.add_argument("--regularize_weight", type=float, default=0.5,
-                        help='regularizing weight of auxiliary loss')
     parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')

     # no training; render only


@@ -1,250 +1,24 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing
 import numpy as np
 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time
 from data_loader_split import load_data_split
-from utils import mse2psnr, img_HWC2CHW, colorize, colorize_np, TINY_NUMBER, to8b
+from utils import mse2psnr, colorize_np, to8b
 import imageio
-from ddp_run_nerf import config_parser
+from ddp_run_nerf import config_parser, setup_logger, setup, cleanup, render_single_image
 import logging

 logger = logging.getLogger(__package__)
-
-def setup_logger():
-    # create logger
-    logger = logging.getLogger(__package__)
-    logger.setLevel(logging.DEBUG)
-    # create console handler and set level to debug
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    # create formatter
-    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
-    # add formatter to ch
-    ch.setFormatter(formatter)
-    # add ch to logger
-    logger.addHandler(ch)
-
-def intersect_sphere(ray_o, ray_d):
-    '''
-    ray_o, ray_d: [..., 3]
-    compute the depth of the intersection point between this ray and unit sphere
-    '''
-    # note: d1 becomes negative if this mid point is behind camera
-    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
-    p = ray_o + d1.unsqueeze(-1) * ray_d
-    # consider the case where the ray does not intersect the sphere
-    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
-    d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos
-    return d1 + d2
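The deleted helper solves ||ray_o + t * ray_d|| = 1 by stepping to the point on the ray closest to the origin (d1) and adding the remaining half-chord (d2). A quick sanity check, assuming the intersect_sphere above is in scope:

import torch

ray_o = torch.tensor([[0.1, 0.0, 0.0]])
ray_d = torch.tensor([[1.0, 0.0, 0.0]])          # unit-norm direction
depth = intersect_sphere(ray_o, ray_d)            # -> tensor([0.9000])
p = ray_o + depth.unsqueeze(-1) * ray_d
assert torch.allclose(p.norm(dim=-1), torch.ones(1), atol=1e-5)   # lands on the unit sphere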
-
-def perturb_samples(z_vals):
-    # get intervals between samples
-    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
-    upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
-    lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
-    # uniform samples in those intervals
-    t_rand = torch.rand_like(z_vals)
-    z_vals = lower + (upper - lower) * t_rand  # [N_rays, N_samples]
-
-    return z_vals
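perturb_samples jitters each depth uniformly inside the bin bounded by the midpoints of its neighbours, so the output keeps its shape and stays inside the original near/far range. A toy check, assuming the function above is in scope:

import torch

z = torch.linspace(0., 1., 5).expand(2, 5)   # two rays, five uniform depths
z_jit = perturb_samples(z)
assert z_jit.shape == z.shape
assert torch.all((z_jit >= 0.) & (z_jit <= 1.))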
-
-def sample_pdf(bins, weights, N_samples, det=False):
-    '''
-    :param bins: tensor of shape [..., M+1], M is the number of bins
-    :param weights: tensor of shape [..., M]
-    :param N_samples: number of samples along each ray
-    :param det: if True, will perform deterministic sampling
-    :return: [..., N_samples]
-    '''
-    # Get pdf
-    weights = weights + TINY_NUMBER      # prevent nans
-    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)    # [..., M]
-    cdf = torch.cumsum(pdf, dim=-1)                              # [..., M]
-    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)   # [..., M+1]
-
-    # Take uniform samples
-    dots_sh = list(weights.shape[:-1])
-    M = weights.shape[-1]
-
-    min_cdf = 0.00
-    max_cdf = 1.00       # prevent outlier samples
-    if det:
-        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
-        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])   # [..., N_samples]
-    else:
-        sh = dots_sh + [N_samples]
-        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf    # [..., N_samples]
-
-    # Invert CDF
-    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
-    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
-
-    # random sample inside each bin
-    below_inds = torch.clamp(above_inds-1, min=0)
-    inds_g = torch.stack((below_inds, above_inds), dim=-1)     # [..., N_samples, 2]
-
-    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])   # [..., N_samples, M+1]
-    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)        # [..., N_samples, 2]
-
-    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])   # [..., N_samples, M+1]
-    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)       # [..., N_samples, 2]
-
-    # fix numeric issue
-    denom = cdf_g[..., 1] - cdf_g[..., 0]      # [..., N_samples]
-    denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom)
-    t = (u - cdf_g[..., 0]) / denom
-
-    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)
-
-    return samples
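sample_pdf is standard inverse-CDF sampling over the bins; when all of the weight sits in one bin, the drawn samples concentrate there. A toy invocation, assuming the function above is in scope:

import torch

bins = torch.tensor([[0.0, 1.0, 2.0, 3.0]])    # M + 1 = 4 bin edges
weights = torch.tensor([[0.0, 1.0, 0.0]])      # all mass in the bin [1, 2]
samples = sample_pdf(bins, weights, N_samples=4, det=True)
# samples ~ tensor([[1.0000, 1.3333, 1.6667, 2.0000]]); all inside [1, 2]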
-
-def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
-    ##### parallel rendering of a single image
-    ray_batch = ray_sampler.get_all()
-    # split into ranks; make sure different processes don't overlap
-    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
-    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
-    for key in ray_batch:
-        if torch.is_tensor(ray_batch[key]):
-            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)
-
-    # split into chunks and render inside each process
-    ray_batch_split = OrderedDict()
-    for key in ray_batch:
-        if torch.is_tensor(ray_batch[key]):
-            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)
-
-    # forward and backward
-    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
-    for s in range(len(ray_batch_split['ray_d'])):
-        ray_o = ray_batch_split['ray_o'][s]
-        ray_d = ray_batch_split['ray_d'][s]
-        min_depth = ray_batch_split['min_depth'][s]
-        dots_sh = list(ray_d.shape[:-1])
-        for m in range(models['cascade_level']):
-            net = models['net_{}'.format(m)]
-            # sample depths
-            N_samples = models['cascade_samples'][m]
-            if m == 0:
-                # foreground depth
-                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
-                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
-                fg_near_depth = min_depth  # [..., 3]
-                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
-                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
-                # background depth
-                bg_depth = torch.linspace(0., 1., N_samples).view(
-                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
-
-                # delete unused memory
-                del fg_near_depth
-                del step
-                torch.cuda.empty_cache()
-            else:
-                # sample pdf and concat with earlier samples
-                fg_weights = ret['fg_weights'].clone().detach()
-                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
-                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
-                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
-                                              N_samples=N_samples, det=True)  # [..., N_samples]
-                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
-
-                # sample pdf and concat with earlier samples
-                bg_weights = ret['bg_weights'].clone().detach()
-                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
-                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
-                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
-                                              N_samples=N_samples, det=True)  # [..., N_samples]
-                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
-
-                # delete unused memory
-                del fg_weights
-                del fg_depth_mid
-                del fg_depth_samples
-                del bg_weights
-                del bg_depth_mid
-                del bg_depth_samples
-                torch.cuda.empty_cache()
-
-            with torch.no_grad():
-                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)
-
-            for key in ret:
-                if key not in ['fg_weights', 'bg_weights']:
-                    if torch.is_tensor(ret[key]):
-                        if key not in ret_merge_chunk[m]:
-                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
-                        else:
-                            ret_merge_chunk[m][key].append(ret[key].cpu())
-                        ret[key] = None
-
-            # clean unused memory
-            torch.cuda.empty_cache()
-
-    # merge results from different chunks
-    for m in range(len(ret_merge_chunk)):
-        for key in ret_merge_chunk[m]:
-            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)
-
-    # merge results from different processes
-    if rank == 0:
-        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
-        for m in range(len(ret_merge_chunk)):
-            for key in ret_merge_chunk[m]:
-                # generate tensors to store results from other processes
-                sh = list(ret_merge_chunk[m][key].shape[1:])
-                ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
-                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
-                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
-                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
-                # print(m, key, ret_merge_rank[m][key].shape)
-    else:  # send results to main process
-        for m in range(len(ret_merge_chunk)):
-            for key in ret_merge_chunk[m]:
-                torch.distributed.gather(ret_merge_chunk[m][key])
-
-    # only rank 0 program returns
-    if rank == 0:
-        return ret_merge_rank
-    else:
-        return None
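The per-rank split above gives each of the world_size processes a contiguous slice of the image's rays, with the last rank absorbing the remainder so every ray is rendered exactly once. A minimal sketch of that arithmetic (split_for_rank is a hypothetical helper name, not part of the commit):

import torch

def split_for_rank(ray_d, rank, world_size):
    n = ray_d.shape[0]
    sizes = [n // world_size] * world_size
    sizes[-1] = n - sum(sizes[:-1])               # last rank takes the remainder
    return torch.split(ray_d, sizes)[rank]

rays = torch.randn(10, 3)                          # 10 toy rays
parts = [split_for_rank(rays, r, 3) for r in range(3)]
assert [p.shape[0] for p in parts] == [3, 3, 4]    # covers all rays exactly once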
-
-def setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
-    # initialize the process group
-    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
-
-
-def cleanup():
-    torch.distributed.destroy_process_group()
-
-
 def ddp_test_nerf(rank, args):
     ###### set up multi-processing
     setup(rank, args.world_size)
@@ -328,7 +102,6 @@ def ddp_test_nerf(rank, args):
         ret = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
         dt = time.time() - time0
         if rank == 0:    # only main process should do this
             logger.info('Rendered {} in {} seconds'.format(fname, dt))

             # only save last level
@@ -342,10 +115,6 @@ def ddp_test_nerf(rank, args):
             im = to8b(im)
             imageio.imwrite(os.path.join(out_dir, fname), im)

-            # im = ret[-1]['diffuse_rgb'].numpy()
-            # im = to8b(im)
-            # imageio.imwrite(os.path.join(out_dir, 'diffuse_' + fname), im)
-
             im = ret[-1]['fg_rgb'].numpy()
             im = to8b(im)
             imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), im)


@@ -7,6 +7,7 @@ from collections import OrderedDict
 import logging
 logger = logging.getLogger(__package__)

+
 class Embedder(nn.Module):
     def __init__(self, input_dim, max_freq_log2, N_freqs,
                  log_sampling=True, include_input=True,
@@ -67,8 +68,8 @@ def weights_init(m):

 class MLPNet(nn.Module):
-    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3, skips=[4], use_viewdirs=False
-                 , use_implicit=False):
+    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3,
+                 skips=[4], use_viewdirs=False):
         '''
         :param D: network depth
         :param W: network width
@@ -79,10 +80,6 @@ class MLPNet(nn.Module):
         '''
         super().__init__()

-        self.use_implicit = use_implicit
-        if self.use_implicit:
-            logger.info('Using implicit regularization as well!')
-
         self.input_ch = input_ch
         self.input_ch_viewdirs = input_ch_viewdirs
         self.use_viewdirs = use_viewdirs
@@ -104,35 +101,21 @@ class MLPNet(nn.Module):
         self.sigma_layers = nn.Sequential(*sigma_layers)
         # self.sigma_layers.apply(weights_init)      # xavier init

-        base_dim = dim
-        # diffuse color
-        diffuse_rgb_layers = []
-        dim = base_dim
-        for i in range(1):
-            diffuse_rgb_layers.append(nn.Linear(dim, W))
-            diffuse_rgb_layers.append(nn.ReLU())
-            dim = W
-        diffuse_rgb_layers.append(nn.Linear(dim, 3))
-        diffuse_rgb_layers.append(nn.Sigmoid())
-        self.diffuse_rgb_layers = nn.Sequential(*diffuse_rgb_layers)
-        # self.diffuse_rgb_layers.apply(weights_init)
-
-        # specular color
-        specular_rgb_layers = []
-        dim = base_dim
+        # rgb color
+        rgb_layers = []
         base_remap_layers = [nn.Linear(dim, 256), ]
         self.base_remap_layers = nn.Sequential(*base_remap_layers)
         # self.base_remap_layers.apply(weights_init)

         dim = 256 + self.input_ch_viewdirs
         for i in range(1):
-            specular_rgb_layers.append(nn.Linear(dim, W))
-            specular_rgb_layers.append(nn.ReLU())
+            rgb_layers.append(nn.Linear(dim, W))
+            rgb_layers.append(nn.ReLU())
             dim = W
-        specular_rgb_layers.append(nn.Linear(dim, 3))
-        specular_rgb_layers.append(nn.Sigmoid())    # rgb values are normalized to [0, 1]
-        self.specular_rgb_layers = nn.Sequential(*specular_rgb_layers)
-        # self.specular_rgb_layers.apply(weights_init)
+        rgb_layers.append(nn.Linear(dim, 3))
+        rgb_layers.append(nn.Sigmoid())
+        self.rgb_layers = nn.Sequential(*rgb_layers)
+        # self.rgb_layers.apply(weights_init)

     def forward(self, input):
         '''
@@ -150,18 +133,10 @@ class MLPNet(nn.Module):
         sigma = self.sigma_layers(base)
         sigma = torch.abs(sigma)

-        diffuse_rgb = self.diffuse_rgb_layers(base)
-
         base_remap = self.base_remap_layers(base)
         input_viewdirs = input[..., -self.input_ch_viewdirs:]
-        specular_rgb = self.specular_rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))
-
-        if self.use_implicit:
-            rgb = specular_rgb
-        else:
-            rgb = diffuse_rgb + specular_rgb
+        rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))

         ret = OrderedDict([('rgb', rgb),
-                           ('diffuse_rgb', diffuse_rgb),
                            ('sigma', sigma.squeeze(-1))])
         return ret
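With the auxiliary branch gone, the network returns only 'rgb' and 'sigma'. A quick shape check (a sketch only; the feature sizes are arbitrary and it assumes the MLPNet above is importable and constructed as shown):

import torch

net = MLPNet(D=8, W=256, input_ch=63, input_ch_viewdirs=27, use_viewdirs=True)
x = torch.rand(5, 63 + 27)                     # positional features followed by view-direction features
out = net(x)
print(sorted(out.keys()))                      # ['rgb', 'sigma']
print(out['rgb'].shape, out['sigma'].shape)    # expected: [5, 3] and [5]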