You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

606 lines
28 KiB
Python

4 years ago
import torch
4 years ago
import torch.nn as nn
4 years ago
import torch.optim
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import os
from collections import OrderedDict
4 years ago
from ddp_model import NerfNetWithAutoExpo
4 years ago
import time
from data_loader_split import load_data_split
import numpy as np
from tensorboardX import SummaryWriter
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
import logging
4 years ago
import json
4 years ago
logger = logging.getLogger(__package__)


def setup_logger():
    '''
    Configure the package logger to print INFO-and-above messages to the console.

    Safe to call more than once in the same process: the console handler is only
    attached on the first call, so repeated calls do not duplicate every log line.

    :return: the configured logger (existing callers may ignore it)
    '''
    handler_name = 'nerf_console'
    logger = logging.getLogger(__package__)
    # logger.setLevel(logging.DEBUG)
    logger.setLevel(logging.INFO)
    # an earlier call already attached our console handler; don't add a second one
    if any(h.get_name() == handler_name for h in logger.handlers):
        return logger
    # console handler passes everything; the logger's own level (INFO) does the filtering
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s'))
    ch.set_name(handler_name)
    logger.addHandler(ch)
    return logger
def intersect_sphere(ray_o, ray_d):
    '''
    Depth (in units of ray_d) at which each ray exits the unit sphere centred at the origin.

    ray_o, ray_d: [..., 3] ray origins and (not necessarily normalized) directions
    return: [...] depth d such that ray_o + d * ray_d lies on the unit sphere

    NOTE(review): no miss handling -- if a ray's line passes farther than 1 from the
    origin, the sqrt argument is negative and the result is NaN. Callers presumably keep
    cameras inside the unit sphere; confirm.
    '''
    # ray parameter of the point closest to the origin; negative if that point is behind the camera
    t_closest = -(ray_d * ray_o).sum(dim=-1) / (ray_d * ray_d).sum(dim=-1)
    closest_pt = ray_o + t_closest[..., None] * ray_d
    # half chord length in world units, rescaled to ray-parameter units by 1/|ray_d|
    inv_dir_norm = 1. / torch.norm(ray_d, dim=-1)
    half_chord = torch.sqrt(1. - (closest_pt * closest_pt).sum(dim=-1)) * inv_dir_norm
    return t_closest + half_chord
def perturb_samples(z_vals):
    '''
    Randomly jitter each sample depth within its own interval.

    The interval of a sample spans the midpoints to its neighbours; the first and last
    samples use their own value as the outer boundary. Depths are assumed sorted along
    the last axis.

    :param z_vals: tensor of sample depths, shape [..., N_samples]
    :return: jittered depths with the same shape as z_vals
    '''
    midpoints = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    left_edge = torch.cat([z_vals[..., :1], midpoints], dim=-1)
    right_edge = torch.cat([midpoints, z_vals[..., -1:]], dim=-1)
    # one uniform draw per sample, mapped linearly into [left_edge, right_edge]
    return left_edge + (right_edge - left_edge) * torch.rand_like(z_vals)
def sample_pdf(bins, weights, N_samples, det=False):
    '''
    Inverse-transform sampling from the piecewise-constant distribution given by
    `weights` over the bin edges `bins`.

    :param bins: tensor of shape [..., M+1] (bin edges), M is the number of bins
    :param weights: tensor of shape [..., M] (unnormalized weight of each bin)
    :param N_samples: number of samples along each ray
    :param det: if True, use evenly spaced quantiles in [0, 1] instead of uniform random ones
    :return: [..., N_samples]
    '''
    # normalized pdf and its cdf, with a leading 0 so the cdf has one entry per bin edge
    weights = weights + TINY_NUMBER  # prevent nans
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
    cdf = torch.cumsum(pdf, dim=-1)  # [..., M]
    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)  # [..., M+1]

    batch_shape = list(weights.shape[:-1])
    num_bins = weights.shape[-1]
    u_lo, u_hi = 0.00, 1.00
    sample_shape = batch_shape + [N_samples]
    if not det:
        u = torch.rand(*sample_shape, device=bins.device) * (u_hi - u_lo) + u_lo  # [..., N_samples]
    else:
        u = torch.linspace(u_lo, u_hi, N_samples, device=bins.device)
        u = u.view([1] * len(batch_shape) + [N_samples]).expand(sample_shape)  # [..., N_samples]

    # invert the cdf: count the first M cdf entries that are <= u, which is the index of the
    # upper edge of the bin containing u; the lower edge is the index just before it
    upper_idx = torch.sum(u.unsqueeze(-1) >= cdf[..., :num_bins].unsqueeze(-2), dim=-1).long()
    lower_idx = torch.clamp(upper_idx - 1, min=0)
    edge_idx = torch.stack((lower_idx, upper_idx), dim=-1)  # [..., N_samples, 2]

    gather_shape = sample_shape + [num_bins + 1]
    cdf_pair = torch.gather(input=cdf.unsqueeze(-2).expand(gather_shape), dim=-1, index=edge_idx)
    bin_pair = torch.gather(input=bins.unsqueeze(-2).expand(gather_shape), dim=-1, index=edge_idx)

    # position of u inside its bin; near-empty bins get denominator 1 to avoid blow-ups
    cdf_span = cdf_pair[..., 1] - cdf_pair[..., 0]  # [..., N_samples]
    cdf_span = torch.where(cdf_span < TINY_NUMBER, torch.ones_like(cdf_span), cdf_span)
    frac = (u - cdf_pair[..., 0]) / cdf_span
    return bin_pair[..., 0] + frac * (bin_pair[..., 1] - bin_pair[..., 0] + TINY_NUMBER)
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
    '''
    Render every ray of one image, splitting the rays across all processes.

    Every rank must call this with the same ray_sampler, because it ends in a
    torch.distributed.gather collective.

    :param rank: index of this process; also used as its CUDA device
    :param world_size: number of processes taking part
    :param models: OrderedDict from create_nerf ('cascade_level', 'cascade_samples', 'net_{m}')
    :param ray_sampler: object providing get_all(), H and W
    :param chunk_size: number of rays pushed through a network at once
    :return: on rank 0, a list (one entry per cascade level) of OrderedDicts mapping each
        network output key to a CPU tensor reshaped to (H, W, -1) and squeezed;
        on other ranks, None
    '''
    ##### parallel rendering of a single image
    ray_batch = ray_sampler.get_all()
    # split into ranks; make sure different processes don't overlap
    # (equal shares, the last rank also takes the remainder)
    rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
    rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)

    # split into chunks and render inside each process
    ray_batch_split = OrderedDict()
    for key in ray_batch:
        if torch.is_tensor(ray_batch[key]):
            ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)

    # forward pass only (the networks run under torch.no_grad below); results per level
    ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
    for s in range(len(ray_batch_split['ray_d'])):
        ray_o = ray_batch_split['ray_o'][s]
        ray_d = ray_batch_split['ray_d'][s]
        min_depth = ray_batch_split['min_depth'][s]

        dots_sh = list(ray_d.shape[:-1])
        for m in range(models['cascade_level']):
            net = models['net_{}'.format(m)]
            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth: N_samples evenly spaced from min_depth to the unit-sphere exit
                fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
                fg_near_depth = min_depth  # per-ray near bound, same leading shape as fg_far_depth
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]

                # background depth: evenly spaced in [0, 1], shared by all rays (no perturbation at render time)
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)

                # delete unused memory
                del fg_near_depth
                del step
                torch.cuda.empty_cache()
            else:
                # importance sampling: draw deterministic samples from the previous level's weights
                # (`ret` is the previous level's output for this same chunk) and merge with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=True)  # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # same for the background samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=True)  # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

                # delete unused memory
                del fg_weights
                del fg_depth_mid
                del fg_depth_samples
                del bg_weights
                del bg_depth_mid
                del bg_depth_samples
                torch.cuda.empty_cache()

            with torch.no_grad():
                ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)

            # move outputs to CPU per chunk; the weights stay on the GPU in `ret`
            # because the next cascade level samples from them
            for key in ret:
                if key not in ['fg_weights', 'bg_weights']:
                    if torch.is_tensor(ret[key]):
                        if key not in ret_merge_chunk[m]:
                            ret_merge_chunk[m][key] = [ret[key].cpu(), ]
                        else:
                            ret_merge_chunk[m][key].append(ret[key].cpu())

                        ret[key] = None

            # clean unused memory
            torch.cuda.empty_cache()

    # merge results from different chunks
    for m in range(len(ret_merge_chunk)):
        for key in ret_merge_chunk[m]:
            ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)

    # merge results from different processes
    if rank == 0:
        ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                # generate tensors to store results from other processes
                # NOTE(review): receive buffers are float32; assumes every network output is float32
                sh = list(ret_merge_chunk[m][key].shape[1:])
                ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
                torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
                # NOTE(review): squeeze() also drops H or W if either equals 1
                ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
                    (ray_sampler.H, ray_sampler.W, -1)).squeeze()
                # print(m, key, ret_merge_rank[m][key].shape)
    else:  # send results to main process (same key order as rank 0 iterates)
        for m in range(len(ret_merge_chunk)):
            for key in ret_merge_chunk[m]:
                torch.distributed.gather(ret_merge_chunk[m][key])

    # only rank 0 program returns
    if rank == 0:
        return ret_merge_rank
    else:
        return None
def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=''):
    '''
    Write a ground-truth image and every cascade level's renderings to tensorboard.

    :param writer: tensorboard SummaryWriter
    :param global_step: step recorded with every image
    :param log_data: list of per-level dicts (as returned by render_single_image on rank 0);
        the keys read are rgb, fg_rgb, fg_depth, bg_rgb, bg_depth and bg_lambda
    :param gt_img: ground-truth image as a numpy array in HWC layout
    :param mask: forwarded to colorize for the depth / lambda maps
    :param prefix: prepended to every tag, e.g. 'val/' or 'train/'
    '''
    writer.add_image(prefix + 'rgb_gt', img_HWC2CHW(torch.from_numpy(gt_img)), global_step)

    # (tag template, key, colormap); a colormap of None means "RGB image clamped to [0, 1]"
    panels = (
        ('level_{}/rgb', 'rgb', None),
        ('level_{}/fg_rgb', 'fg_rgb', None),
        ('level_{}/fg_depth', 'fg_depth', 'jet'),
        ('level_{}/bg_rgb', 'bg_rgb', None),
        ('level_{}/bg_depth', 'bg_depth', 'jet'),
        ('level_{}/bg_lambda', 'bg_lambda', 'hot'),
    )
    for level, level_data in enumerate(log_data):
        for tag, key, cmap in panels:
            if cmap is None:
                # just in case diffuse+specular>1
                image = torch.clamp(img_HWC2CHW(level_data[key]), min=0., max=1.)
            else:
                image = img_HWC2CHW(colorize(level_data[key], cmap_name=cmap, append_cbar=True,
                                             mask=mask))
            writer.add_image(prefix + tag.format(level), image, global_step)
def setup(rank, world_size, master_addr='localhost', master_port='12355', backend='gloo'):
    '''
    Join the default torch.distributed process group for this worker.

    :param rank: index of this process
    :param world_size: total number of processes
    :param master_addr: rendezvous address of rank 0 (default 'localhost')
    :param master_port: rendezvous port; pass a different value to run several
        jobs on the same machine without a port clash (default '12355')
    :param backend: process-group backend; the default 'gloo' supports the CPU
        tensors that render_single_image gathers
    '''
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    # initialize the process group
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
def cleanup():
    '''Tear down the default torch.distributed process group created by setup().'''
    dist = torch.distributed
    dist.destroy_process_group()
def create_nerf(rank, args):
    '''
    Build one DDP-wrapped network and Adam optimizer per cascade level, then optionally
    restore them from the newest checkpoint.

    Every process must call this, since each rank constructs its own DDP replica.

    :param rank: process index; used as the CUDA device id
    :param args: parsed command-line arguments (see config_parser)
    :return: (start, models). start is the iteration parsed from the reloaded checkpoint's
        file name, or -1 if nothing was reloaded. models is an OrderedDict with keys
        'cascade_level', 'cascade_samples', 'net_{m}' and 'optim_{m}'.
    '''
    ###### create network and wrap in ddp; each process should do this
    # fix random seed just to make sure the network is initialized with same weights at different processes
    torch.manual_seed(777)
    # very important!!! otherwise it might introduce extra memory in rank=0 gpu
    torch.cuda.set_device(rank)

    models = OrderedDict()
    models['cascade_level'] = args.cascade_level
    models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]

    # training image names for autoexposure are identical for every level: read them once
    img_names = None
    if args.optim_autoexpo:
        f = os.path.join(args.basedir, args.expname, 'train_images.json')
        with open(f) as file:
            img_names = json.load(file)

    for m in range(models['cascade_level']):
        net = NerfNetWithAutoExpo(args, optim_autoexpo=args.optim_autoexpo, img_names=img_names).to(rank)
        net = DDP(net, device_ids=[rank], output_device=rank, find_unused_parameters=True)
        optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
        models['net_{}'.format(m)] = net
        models['optim_{}'.format(m)] = optim

    start = -1

    ###### load pretrained weights; each process should do this
    if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
        ckpts = [args.ckpt_path]
    else:
        ckpts = [os.path.join(args.basedir, args.expname, f)
                 for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]

    def path2iter(path):
        # iteration number = integer after the last '_' once the 4-char extension is removed
        # (e.g. 'model_000123.pth' -> 123)
        tmp = os.path.basename(path)[:-4]
        idx = tmp.rfind('_')
        return int(tmp[idx + 1:])

    ckpts = sorted(ckpts, key=path2iter)
    logger.info('Found ckpts: {}'.format(ckpts))
    if len(ckpts) > 0 and not args.no_reload:
        fpath = ckpts[-1]
        logger.info('Reloading from: {}'.format(fpath))
        start = path2iter(fpath)
        # configure map_location properly for different processes:
        # tensors stored on cuda:0 are remapped to this process' device
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        to_load = torch.load(fpath, map_location=map_location)
        for m in range(models['cascade_level']):
            for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
                # load each state dict exactly once (it was previously loaded twice)
                models[name].load_state_dict(to_load[name])

    return start, models
def _write_experiment_config(args):
    '''Create <basedir>/<expname> and record the parsed arguments (and config file, if given) in it.'''
    os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
    f = os.path.join(args.basedir, args.expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(args.basedir, args.expname, 'config.txt')
        # close the source config file as well as the copy
        with open(f, 'w') as file, open(args.config, 'r') as src:
            file.write(src.read())


def _save_checkpoint(models, fpath):
    '''Save every cascade level's network and optimizer state dict to fpath.'''
    to_save = OrderedDict()
    for m in range(models['cascade_level']):
        name = 'net_{}'.format(m)
        to_save[name] = models[name].state_dict()
        name = 'optim_{}'.format(m)
        to_save[name] = models[name].state_dict()
    torch.save(to_save, fpath)


def ddp_train_nerf(rank, args):
    '''
    Training loop run by each process started with torch.multiprocessing.spawn.

    Sets up the process group and logger, picks the ray batch / chunk sizes from the GPU
    memory, loads the data, builds (or reloads) the cascade of networks and runs
    args.N_iters optimization steps. Rank 0 additionally writes the experiment files,
    tensorboard scalars/images and checkpoints; all ranks take part in rendering the
    logged views.

    :param rank: process index supplied by spawn; also used as the CUDA device id
    :param args: parsed command-line arguments (see config_parser)
    '''
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### decide chunk size according to gpu memory
    # NOTE: this overrides any N_rand / chunk_size given on the command line or in the config
    total_memory = torch.cuda.get_device_properties(rank).total_memory
    logger.info('gpu_mem: {}'.format(total_memory))
    if total_memory / 1e9 > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    else:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096

    ###### Create log dir and copy the config file
    if rank == 0:
        _write_experiment_config(args)
    torch.distributed.barrier()

    ray_samplers = load_data_split(args.datadir, args.scene, split='train',
                                   try_load_min_depth=args.load_min_depth)
    val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation',
                                       try_load_min_depth=args.load_min_depth, skip=args.testskip)

    # write training image names for autoexposure. Only rank 0 writes, and every rank waits
    # at the barrier, so create_nerf never reads a file another process is still rewriting.
    if args.optim_autoexpo:
        if rank == 0:
            f = os.path.join(args.basedir, args.expname, 'train_images.json')
            with open(f, 'w') as file:
                img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))]
                json.dump(img_names, file, indent=2)
        torch.distributed.barrier()

    ###### create network and wrap in ddp; each process should do this
    start, models = create_nerf(rank, args)

    ##### important!!!
    # make sure different processes sample different rays
    np.random.seed((rank + 1) * 777)
    # make sure different processes have different perturbations in depth samples
    torch.manual_seed((rank + 1) * 777)

    ##### only main process should do the logging
    if rank == 0:
        writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))

    # start training
    what_val_to_log = 0  # helper variable for parallel rendering of a image
    what_train_to_log = 0
    for global_step in range(start+1, start+1+args.N_iters):
        time0 = time.time()
        scalars_to_log = OrderedDict()
        ### Start of core optimization loop
        scalars_to_log['resolution'] = ray_samplers[0].resolution_level
        # randomly sample rays and move to device
        i = np.random.randint(low=0, high=len(ray_samplers))
        ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
        for key in ray_batch:
            if torch.is_tensor(ray_batch[key]):
                ray_batch[key] = ray_batch[key].to(rank)

        # forward and backward
        dots_sh = list(ray_batch['ray_d'].shape[:-1])  # number of rays
        all_rets = []  # results on different cascade levels
        for m in range(models['cascade_level']):
            optim = models['optim_{}'.format(m)]
            net = models['net_{}'.format(m)]

            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth
                fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d'])  # [...,]
                fg_near_depth = ray_batch['min_depth']  # per-ray near bound
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
                fg_depth = perturb_samples(fg_depth)  # random perturbation during training

                # background depth
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
                bg_depth = perturb_samples(bg_depth)  # random perturbation during training
            else:
                # sample pdf (from the previous level's weights) and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]  # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]  # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

            optim.zero_grad()
            ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch['img_name'])
            all_rets.append(ret)

            rgb_gt = ray_batch['rgb'].to(rank)
            if 'autoexpo' in ret:
                scale, shift = ret['autoexpo']
                scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item()
                scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item()
                # undo the predicted exposure before comparing with the ground truth
                rgb_pred = (ret['rgb'] - shift) / scale
                rgb_loss = img2mse(rgb_pred, rgb_gt)
                # regularize the exposure towards the identity (scale 1, shift 0)
                loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale-1.)+torch.abs(shift))
            else:
                rgb_loss = img2mse(ret['rgb'], rgb_gt)
                loss = rgb_loss
            scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item()
            scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(rgb_loss.item())

            loss.backward()
            optim.step()

        ### end of core optimization loop
        dt = time.time() - time0
        scalars_to_log['iter_time'] = dt

        ### only main process should do the logging
        if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
            logstr = '{} step: {} '.format(args.expname, global_step)
            for k in scalars_to_log:
                logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                writer.add_scalar(k, scalars_to_log[k], global_step)
            logger.info(logstr)

        ### each process should do this; but only main process merges the results
        if global_step % args.i_img == 0 or global_step == start+1:
            #### critical: make sure each process is working on the same random image
            time0 = time.time()
            idx = what_val_to_log % len(val_ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
            what_val_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info('Logged a random validation view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix='val/')

            time0 = time.time()
            idx = what_train_to_log % len(ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
            what_train_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info('Logged a random training view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/')

            del log_data
            torch.cuda.empty_cache()

        if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
            # saving checkpoints and logging
            fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
            _save_checkpoint(models, fpath)

    # flush pending tensorboard events before the process exits
    if rank == 0:
        writer.close()

    # clean up for multi-processing
    cleanup()
def config_parser():
    '''
    Build the command-line / config-file argument parser for training.

    :return: a configargparse.ArgumentParser; values may come from the command line or
        from the file given with --config
    '''
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
    # dataset options
    parser.add_argument("--datadir", type=str, default=None, help='input data directory')
    parser.add_argument("--scene", type=str, default=None, help='scene name')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
    # model size
    parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
    # checkpoints
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
    parser.add_argument("--ckpt_path", type=str, default=None,
                        help='specific checkpoint (.pth) file to reload for all cascade levels')
    # batch size
    parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--chunk_size", type=int, default=1024 * 8,
                        help='number of rays processed in parallel, decrease if running out of memory')
    # iterations
    parser.add_argument("--N_iters", type=int, default=250001,
                        help='number of iterations')
    # render only
    parser.add_argument("--render_splits", type=str, default='test',
                        help='splits to render')
    # cascade training
    parser.add_argument("--cascade_level", type=int, default=2,
                        help='number of cascade levels')
    parser.add_argument("--cascade_samples", type=str, default='64,64',
                        help='samples at each level')
    parser.add_argument("--devices", type=str, default='0,1',
                        help='cuda device for each level')
    parser.add_argument("--bg_devices", type=str, default='0,2',
                        help='cuda device for the background of each level')
    # multiprocess learning
    parser.add_argument("--world_size", type=int, default=-1,
                        help='number of processes (-1: one per visible GPU)')
    # optimize autoexposure
    parser.add_argument("--optim_autoexpo", action='store_true',
                        help='optimize autoexposure parameters')
    parser.add_argument("--lambda_autoexpo", type=float, default=1., help='regularization weight for autoexposure')
    # learning rate options
    # NOTE(review): the lrate_decay_* options are not read by the training loop in this file
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
                        help='decay learning rate by a factor every specified number of steps')
    parser.add_argument("--lrate_decay_steps", type=int, default=5000,
                        help='decay learning rate by a factor every specified number of steps')
    # rendering options
    parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
    parser.add_argument("--max_freq_log2", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')
    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric logging')
    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
    return parser
def train():
    '''
    Parse the training arguments and launch one ddp_train_nerf process per rank.

    A --world_size of -1 is replaced by the number of visible CUDA devices.
    '''
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))

    # spawn passes each process its index as the first argument (the rank)
    spawn_kwargs = dict(args=(args,), nprocs=args.world_size, join=True)
    torch.multiprocessing.spawn(ddp_train_nerf, **spawn_kwargs)
if __name__ == '__main__':
    # console logging first, so argument parsing and GPU discovery are reported
    setup_logger()
    train()