diff --git a/.gitignore b/.gitignore
index 5391d87..57170e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,9 @@
+# scripts
+*.sh
+
+# pycharm
+.idea/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/ddp_model.py b/ddp_model.py
index d10ff0e..1d98f05 100644
--- a/ddp_model.py
+++ b/ddp_model.py
@@ -63,8 +63,7 @@ class NerfNet(nn.Module):
         self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.fg_embedder_position.out_dim,
                              input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)
         # background; bg_pt is (x, y, z, 1/r)
         self.bg_embedder_position = Embedder(input_dim=4,
                                              max_freq_log2=args.max_freq_log2 - 1,
@@ -75,8 +74,7 @@ class NerfNet(nn.Module):
         self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.bg_embedder_position.out_dim,
                              input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)
 
     def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
         '''
@@ -109,7 +107,6 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1)  # [..., N_samples]
         fg_weights = fg_alpha * T  # [..., N_samples]
         fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2)  # [..., 3]
-        fg_diffuse_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1)  # [...,]
 
         # render background
@@ -133,18 +130,14 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1)  # [..., N_samples]
         bg_weights = bg_alpha * T  # [..., N_samples]
         bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2)  # [..., 3]
-        bg_diffuse_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1)  # [...,]
 
         # composite foreground and background
         bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
-        bg_diffuse_rgb_map = bg_lambda.unsqueeze(-1) * bg_diffuse_rgb_map
         bg_depth_map = bg_lambda * bg_depth_map
         rgb_map = fg_rgb_map + bg_rgb_map
-        diffuse_rgb_map = fg_diffuse_rgb_map + bg_diffuse_rgb_map
 
         ret = OrderedDict([('rgb', rgb_map),            # loss
-                           ('diffuse_rgb', diffuse_rgb_map),  # regularize
                            ('fg_weights', fg_weights),  # importance sampling
                            ('bg_weights', bg_weights),  # importance sampling
                            ('fg_rgb', fg_rgb_map),      # below are for logging
diff --git a/ddp_run_nerf.py b/ddp_run_nerf.py
index 8c8d5bb..c1fdbe3 100644
--- a/ddp_run_nerf.py
+++ b/ddp_run_nerf.py
@@ -1,22 +1,19 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing
-
 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time
-
 from data_loader_split import load_data_split
-
 import numpy as np
 from tensorboardX import SummaryWriter
 from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
-
 import logging
+
 logger = logging.getLogger(__package__)
 
 
@@ -427,11 +424,6 @@ def ddp_train_nerf(rank, args):
             loss = img2mse(ret['rgb'], rgb_gt)
             scalars_to_log['level_{}/loss'.format(m)] = loss.item()
             scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
-            # regularize sigma with photo-consistency
-            diffuse_loss = img2mse(ret['diffuse_rgb'], rgb_gt)
-            scalars_to_log['level_{}/diffuse_loss'.format(m)] = diffuse_loss.item()
-            scalars_to_log['level_{}/diffuse_psnr'.format(m)] = mse2psnr(diffuse_loss.item())
-            loss = (1. - args.regularize_weight) * loss + args.regularize_weight * diffuse_loss
             loss.backward()
             optim.step()
 
@@ -571,9 +563,6 @@ def config_parser():
                         help='apply the trick to avoid fitting to white background')
 
     # use implicit
-    parser.add_argument("--use_implicit", action='store_true', help='whether to use implicit regularization')
-    parser.add_argument("--regularize_weight", type=float, default=0.5,
-                        help='regularizing weight of auxiliary loss')
     parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')
 
     # no training; render only
diff --git a/ddp_test_nerf.py b/ddp_test_nerf.py
index 1b25043..266d657 100644
--- a/ddp_test_nerf.py
+++ b/ddp_test_nerf.py
@@ -1,250 +1,24 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing
 import numpy as np
-
 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time
 from data_loader_split import load_data_split
-from utils import mse2psnr, img_HWC2CHW, colorize, colorize_np, TINY_NUMBER, to8b
+from utils import mse2psnr, colorize_np, to8b
 import imageio
-from ddp_run_nerf import config_parser
-
+from ddp_run_nerf import config_parser, setup_logger, setup, cleanup, render_single_image
 import logging
 
 logger = logging.getLogger(__package__)
 
 
-def setup_logger():
-    # create logger
-    logger = logging.getLogger(__package__)
-    logger.setLevel(logging.DEBUG)
-
-    # create console handler and set level to debug
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-
-    # create formatter
-    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
-
-    # add formatter to ch
-    ch.setFormatter(formatter)
-
-    # add ch to logger
-    logger.addHandler(ch)
-
-
-def intersect_sphere(ray_o, ray_d):
-    '''
-    ray_o, ray_d: [..., 3]
-    compute the depth of the intersection point between this ray and unit sphere
-    '''
-    # note: d1 becomes negative if this mid point is behind camera
-    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
-    p = ray_o + d1.unsqueeze(-1) * ray_d
-    # consider the case where the ray does not intersect the sphere
-    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
-    d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos
-
-    return d1 + d2
-
-
-def perturb_samples(z_vals):
-    # get intervals between samples
-    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
-    upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
-    lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
-    # uniform samples in those intervals
-    t_rand = torch.rand_like(z_vals)
-    z_vals = lower + (upper - lower) * t_rand  # [N_rays, N_samples]
-
-    return z_vals
-
-
-def sample_pdf(bins, weights, N_samples, det=False):
-    '''
-    :param bins: tensor of shape [..., M+1], M is the number of bins
-    :param weights: tensor of shape [..., M]
-    :param N_samples: number of samples along each ray
-    :param det: if True, will perform deterministic sampling
-    :return: [..., N_samples]
-    '''
-    # Get pdf
-    weights = weights + TINY_NUMBER      # prevent nans
-    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)    # [..., M]
-    cdf = torch.cumsum(pdf, dim=-1)                             # [..., M]
-    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)     # [..., M+1]
-
-    # Take uniform samples
-    dots_sh = list(weights.shape[:-1])
-    M = weights.shape[-1]
-
-    min_cdf = 0.00
-    max_cdf = 1.00       # prevent outlier samples
-
-    if det:
-        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
-        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])   # [..., N_samples]
-    else:
-        sh = dots_sh + [N_samples]
-        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf     # [..., N_samples]
-
-    # Invert CDF
-    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
-    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
-
-    # random sample inside each bin
-    below_inds = torch.clamp(above_inds-1, min=0)
-    inds_g = torch.stack((below_inds, above_inds), dim=-1)     # [..., N_samples, 2]
-
-    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])   # [..., N_samples, M+1]
-    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)       # [..., N_samples, 2]
-
-    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])    # [..., N_samples, M+1]
-    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]
-
-    # fix numeric issue
-    denom = cdf_g[..., 1] - cdf_g[..., 0]      # [..., N_samples]
-    denom = torch.where(denom