remove auxiliary loss

parent e60d030d34
commit 2af283ff67

5 changed files with 25 additions and 293 deletions
.gitignore (vendored, 6 changes)

@@ -1,3 +1,9 @@
+# scripts
+*.sh
+
+# pycharm
+.idea/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
ddp_model.py (11 changes)

@@ -63,8 +63,7 @@ class NerfNet(nn.Module):
         self.fg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.fg_embedder_position.out_dim,
                              input_ch_viewdirs=self.fg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)
         # background; bg_pt is (x, y, z, 1/r)
         self.bg_embedder_position = Embedder(input_dim=4,
                                              max_freq_log2=args.max_freq_log2 - 1,
@@ -75,8 +74,7 @@ class NerfNet(nn.Module):
         self.bg_net = MLPNet(D=args.netdepth, W=args.netwidth,
                              input_ch=self.bg_embedder_position.out_dim,
                              input_ch_viewdirs=self.bg_embedder_viewdir.out_dim,
-                             use_viewdirs=args.use_viewdirs,
-                             use_implicit=args.use_implicit)
+                             use_viewdirs=args.use_viewdirs)

     def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals):
         '''
@@ -109,7 +107,6 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T[..., :-1]), dim=-1)  # [..., N_samples]
         fg_weights = fg_alpha * T  # [..., N_samples]
         fg_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['rgb'], dim=-2)  # [..., 3]
-        fg_diffuse_rgb_map = torch.sum(fg_weights.unsqueeze(-1) * fg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         fg_depth_map = torch.sum(fg_weights * fg_z_vals, dim=-1)  # [...,]

         # render background
@@ -133,18 +130,14 @@ class NerfNet(nn.Module):
         T = torch.cat((torch.ones_like(T[..., 0:1]), T), dim=-1)  # [..., N_samples]
         bg_weights = bg_alpha * T  # [..., N_samples]
         bg_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['rgb'], dim=-2)  # [..., 3]
-        bg_diffuse_rgb_map = torch.sum(bg_weights.unsqueeze(-1) * bg_raw['diffuse_rgb'], dim=-2)  # [..., 3]
         bg_depth_map = torch.sum(bg_weights * bg_z_vals, dim=-1)  # [...,]

         # composite foreground and background
         bg_rgb_map = bg_lambda.unsqueeze(-1) * bg_rgb_map
-        bg_diffuse_rgb_map = bg_lambda.unsqueeze(-1) * bg_diffuse_rgb_map
         bg_depth_map = bg_lambda * bg_depth_map
         rgb_map = fg_rgb_map + bg_rgb_map
-        diffuse_rgb_map = fg_diffuse_rgb_map + bg_diffuse_rgb_map

         ret = OrderedDict([('rgb', rgb_map),            # loss
-                           ('diffuse_rgb', diffuse_rgb_map),  # regularize
                            ('fg_weights', fg_weights),  # importance sampling
                            ('bg_weights', bg_weights),  # importance sampling
                            ('fg_rgb', fg_rgb_map),      # below are for logging

ddp_run_nerf.py

@@ -1,22 +1,19 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing

 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time

 from data_loader_split import load_data_split

 import numpy as np
 from tensorboardX import SummaryWriter
 from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER

 import logging

 logger = logging.getLogger(__package__)

@@ -427,11 +424,6 @@ def ddp_train_nerf(rank, args):
            loss = img2mse(ret['rgb'], rgb_gt)
            scalars_to_log['level_{}/loss'.format(m)] = loss.item()
            scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
-           # regularize sigma with photo-consistency
-           diffuse_loss = img2mse(ret['diffuse_rgb'], rgb_gt)
-           scalars_to_log['level_{}/diffuse_loss'.format(m)] = diffuse_loss.item()
-           scalars_to_log['level_{}/diffuse_psnr'.format(m)] = mse2psnr(diffuse_loss.item())
-           loss = (1. - args.regularize_weight) * loss + args.regularize_weight * diffuse_loss
            loss.backward()
            optim.step()

@@ -571,9 +563,6 @@ def config_parser():
                        help='apply the trick to avoid fitting to white background')

    # use implicit
-   parser.add_argument("--use_implicit", action='store_true', help='whether to use implicit regularization')
-   parser.add_argument("--regularize_weight", type=float, default=0.5,
-                       help='regularizing weight of auxiliary loss')
    parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')

    # no training; render only

ddp_test_nerf.py (237 changes)

@@ -1,250 +1,24 @@
 import torch
-import torch.nn as nn
+# import torch.nn as nn
 import torch.optim
 import torch.distributed
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing
 import numpy as np

 import os
 from collections import OrderedDict
 from ddp_model import NerfNet
 import time
 from data_loader_split import load_data_split
-from utils import mse2psnr, img_HWC2CHW, colorize, colorize_np, TINY_NUMBER, to8b
+from utils import mse2psnr, colorize_np, to8b
 import imageio
-from ddp_run_nerf import config_parser
+from ddp_run_nerf import config_parser, setup_logger, setup, cleanup, render_single_image

 import logging


 logger = logging.getLogger(__package__)


-def setup_logger():
-    # create logger
-    logger = logging.getLogger(__package__)
-    logger.setLevel(logging.DEBUG)
-
-    # create console handler and set level to debug
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-
-    # create formatter
-    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
-
-    # add formatter to ch
-    ch.setFormatter(formatter)
-
-    # add ch to logger
-    logger.addHandler(ch)
-
-
-def intersect_sphere(ray_o, ray_d):
-    '''
-    ray_o, ray_d: [..., 3]
-    compute the depth of the intersection point between this ray and unit sphere
-    '''
-    # note: d1 becomes negative if this mid point is behind camera
-    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
-    p = ray_o + d1.unsqueeze(-1) * ray_d
-    # consider the case where the ray does not intersect the sphere
-    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
-    d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos
-
-    return d1 + d2
-
-
-def perturb_samples(z_vals):
-    # get intervals between samples
-    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
-    upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
-    lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
-    # uniform samples in those intervals
-    t_rand = torch.rand_like(z_vals)
-    z_vals = lower + (upper - lower) * t_rand  # [N_rays, N_samples]
-
-    return z_vals
-
-
-def sample_pdf(bins, weights, N_samples, det=False):
-    '''
-    :param bins: tensor of shape [..., M+1], M is the number of bins
-    :param weights: tensor of shape [..., M]
-    :param N_samples: number of samples along each ray
-    :param det: if True, will perform deterministic sampling
-    :return: [..., N_samples]
-    '''
-    # Get pdf
-    weights = weights + TINY_NUMBER  # prevent nans
-    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
-    cdf = torch.cumsum(pdf, dim=-1)  # [..., M]
-    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)  # [..., M+1]
-
-    # Take uniform samples
-    dots_sh = list(weights.shape[:-1])
-    M = weights.shape[-1]
-
-    min_cdf = 0.00
-    max_cdf = 1.00  # prevent outlier samples
-
-    if det:
-        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
-        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])  # [..., N_samples]
-    else:
-        sh = dots_sh + [N_samples]
-        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf  # [..., N_samples]
-
-    # Invert CDF
-    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
-    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
-
-    # random sample inside each bin
-    below_inds = torch.clamp(above_inds-1, min=0)
-    inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N_samples, 2]
-
-    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
-    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)  # [..., N_samples, 2]
-
-    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
-    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]
-
-    # fix numeric issue
-    denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N_samples]
-    denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom)
-    t = (u - cdf_g[..., 0]) / denom
-
-    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)
-
-    return samples
-
-
-def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
-    ##### parallel rendering of a single image
-    ray_batch = ray_sampler.get_all()
-    # split into ranks; make sure different processes don't overlap
-    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
-    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
-    for key in ray_batch:
-        if torch.is_tensor(ray_batch[key]):
-            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)
-
-    # split into chunks and render inside each process
-    ray_batch_split = OrderedDict()
-    for key in ray_batch:
-        if torch.is_tensor(ray_batch[key]):
-            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)
-
-    # forward and backward
-    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
-    for s in range(len(ray_batch_split['ray_d'])):
-        ray_o = ray_batch_split['ray_o'][s]
-        ray_d = ray_batch_split['ray_d'][s]
-        min_depth = ray_batch_split['min_depth'][s]
-
-        dots_sh = list(ray_d.shape[:-1])
-        for m in range(models['cascade_level']):
-            net = models['net_{}'.format(m)]
-            # sample depths
-            N_samples = models['cascade_samples'][m]
-            if m == 0:
-                # foreground depth
-                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
-                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
-                fg_near_depth = min_depth  # [..., 3]
-                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
-                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
-
-                # background depth
-                bg_depth = torch.linspace(0., 1., N_samples).view(
-                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
-
-                # delete unused memory
-                del fg_near_depth
-                del step
-                torch.cuda.empty_cache()
-            else:
-                # sample pdf and concat with earlier samples
-                fg_weights = ret['fg_weights'].clone().detach()
-                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
-                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
-                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
-                                              N_samples=N_samples, det=True)  # [..., N_samples]
-                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
-
-                # sample pdf and concat with earlier samples
-                bg_weights = ret['bg_weights'].clone().detach()
-                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
-                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
-                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
-                                              N_samples=N_samples, det=True)  # [..., N_samples]
-                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
-
-                # delete unused memory
-                del fg_weights
-                del fg_depth_mid
-                del fg_depth_samples
-                del bg_weights
-                del bg_depth_mid
-                del bg_depth_samples
-                torch.cuda.empty_cache()
-
-            with torch.no_grad():
-                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)
-
-            for key in ret:
-                if key not in ['fg_weights', 'bg_weights']:
-                    if torch.is_tensor(ret[key]):
-                        if key not in ret_merge_chunk[m]:
-                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
-                        else:
-                            ret_merge_chunk[m][key].append(ret[key].cpu())
-
-                        ret[key] = None
-
-            # clean unused memory
-            torch.cuda.empty_cache()
-
-    # merge results from different chunks
-    for m in range(len(ret_merge_chunk)):
-        for key in ret_merge_chunk[m]:
-            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)
-
-    # merge results from different processes
-    if rank == 0:
-        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
-        for m in range(len(ret_merge_chunk)):
-            for key in ret_merge_chunk[m]:
-                # generate tensors to store results from other processes
-                sh = list(ret_merge_chunk[m][key].shape[1:])
-                ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
-                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
-                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
-                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
-                # print(m, key, ret_merge_rank[m][key].shape)
-    else:  # send results to main process
-        for m in range(len(ret_merge_chunk)):
-            for key in ret_merge_chunk[m]:
-                torch.distributed.gather(ret_merge_chunk[m][key])
-
-    # only rank 0 program returns
-    if rank == 0:
-        return ret_merge_rank
-    else:
-        return None
-
-
-def setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
-    # initialize the process group
-    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
-
-
-def cleanup():
-    torch.distributed.destroy_process_group()
-
-
 def ddp_test_nerf(rank, args):
     ###### set up multi-processing
     setup(rank, args.world_size)

@@ -328,7 +102,6 @@ def ddp_test_nerf(rank, args):
        ret = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
        dt = time.time() - time0
        if rank == 0:    # only main process should do this
-
            logger.info('Rendered {} in {} seconds'.format(fname, dt))

        # only save last level

@@ -342,10 +115,6 @@ def ddp_test_nerf(rank, args):
            im = to8b(im)
            imageio.imwrite(os.path.join(out_dir, fname), im)

-           # im = ret[-1]['diffuse_rgb'].numpy()
-           # im = to8b(im)
-           # imageio.imwrite(os.path.join(out_dir, 'diffuse_' + fname), im)
-
            im = ret[-1]['fg_rgb'].numpy()
            im = to8b(im)
            imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), im)

nerf_network.py

@@ -7,6 +7,7 @@ from collections import OrderedDict
 import logging
 logger = logging.getLogger(__package__)
+

 class Embedder(nn.Module):
     def __init__(self, input_dim, max_freq_log2, N_freqs,
                  log_sampling=True, include_input=True,
@@ -67,8 +68,8 @@ def weights_init(m):


 class MLPNet(nn.Module):
-    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3, skips=[4], use_viewdirs=False
-                 , use_implicit=False):
+    def __init__(self, D=8, W=256, input_ch=3, input_ch_viewdirs=3,
+                 skips=[4], use_viewdirs=False):
         '''
         :param D: network depth
         :param W: network width
@@ -79,10 +80,6 @@ class MLPNet(nn.Module):
         '''
         super().__init__()

-        self.use_implicit = use_implicit
-        if self.use_implicit:
-            logger.info('Using implicit regularization as well!')
-
         self.input_ch = input_ch
         self.input_ch_viewdirs = input_ch_viewdirs
         self.use_viewdirs = use_viewdirs
@@ -104,35 +101,21 @@ class MLPNet(nn.Module):
         self.sigma_layers = nn.Sequential(*sigma_layers)
         # self.sigma_layers.apply(weights_init)    # xavier init

-        base_dim = dim
-        # diffuse color
-        diffuse_rgb_layers = []
-        dim = base_dim
-        for i in range(1):
-            diffuse_rgb_layers.append(nn.Linear(dim, W))
-            diffuse_rgb_layers.append(nn.ReLU())
-            dim = W
-        diffuse_rgb_layers.append(nn.Linear(dim, 3))
-        diffuse_rgb_layers.append(nn.Sigmoid())
-        self.diffuse_rgb_layers = nn.Sequential(*diffuse_rgb_layers)
-        # self.diffuse_rgb_layers.apply(weights_init)
-
-        # specular color
-        specular_rgb_layers = []
-        dim = base_dim
+        # rgb color
+        rgb_layers = []
         base_remap_layers = [nn.Linear(dim, 256), ]
         self.base_remap_layers = nn.Sequential(*base_remap_layers)
         # self.base_remap_layers.apply(weights_init)

         dim = 256 + self.input_ch_viewdirs
         for i in range(1):
-            specular_rgb_layers.append(nn.Linear(dim, W))
-            specular_rgb_layers.append(nn.ReLU())
+            rgb_layers.append(nn.Linear(dim, W))
+            rgb_layers.append(nn.ReLU())
             dim = W
-        specular_rgb_layers.append(nn.Linear(dim, 3))
-        specular_rgb_layers.append(nn.Sigmoid())      # rgb values are normalized to [0, 1]
-        self.specular_rgb_layers = nn.Sequential(*specular_rgb_layers)
-        # self.specular_rgb_layers.apply(weights_init)
+        rgb_layers.append(nn.Linear(dim, 3))
+        rgb_layers.append(nn.Sigmoid())      # rgb values are normalized to [0, 1]
+        self.rgb_layers = nn.Sequential(*rgb_layers)
+        # self.rgb_layers.apply(weights_init)

     def forward(self, input):
         '''
@@ -150,18 +133,10 @@ class MLPNet(nn.Module):
         sigma = self.sigma_layers(base)
         sigma = torch.abs(sigma)

-        diffuse_rgb = self.diffuse_rgb_layers(base)
-
         base_remap = self.base_remap_layers(base)
         input_viewdirs = input[..., -self.input_ch_viewdirs:]
-        specular_rgb = self.specular_rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))
-
-        if self.use_implicit:
-            rgb = specular_rgb
-        else:
-            rgb = diffuse_rgb + specular_rgb
+        rgb = self.rgb_layers(torch.cat((base_remap, input_viewdirs), dim=-1))

         ret = OrderedDict([('rgb', rgb),
-                           ('diffuse_rgb', diffuse_rgb),
                            ('sigma', sigma.squeeze(-1))])
         return ret