From 286d69250f327124e1c3a52da608e407b5ec1dd9 Mon Sep 17 00:00:00 2001 From: Otthorn Date: Thu, 10 Jun 2021 17:03:06 +0200 Subject: [PATCH] isort, black and tqdm --- ddp_train_nerf.py | 865 +++++++++++++++++++++++++++++++++------------- 1 file changed, 616 insertions(+), 249 deletions(-) diff --git a/ddp_train_nerf.py b/ddp_train_nerf.py index 8988318..9a20aac 100644 --- a/ddp_train_nerf.py +++ b/ddp_train_nerf.py @@ -1,20 +1,22 @@ -import torch -import torch.nn as nn -import torch.optim -import torch.distributed -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.multiprocessing +import json +import logging import os -from collections import OrderedDict -from ddp_model import NerfNetWithAutoExpo import time -from data_loader_split import load_data_split +from collections import OrderedDict + import numpy as np +import torch +import torch.distributed +import torch.multiprocessing +import torch.nn as nn +import torch.optim from tensorboardX import SummaryWriter -from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER -import logging -import json +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm +from data_loader_split import load_data_split +from ddp_model import NerfNetWithAutoExpo +from utils import TINY_NUMBER, colorize, img2mse, img_HWC2CHW, mse2psnr logger = logging.getLogger(__package__) @@ -30,7 +32,9 @@ def setup_logger(): ch.setLevel(logging.DEBUG) # create formatter - formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') + formatter = logging.Formatter( + "%(asctime)s [%(levelname)s] %(name)s: %(message)s" + ) # add formatter to ch ch.setFormatter(formatter) @@ -40,26 +44,29 @@ def setup_logger(): def intersect_sphere(ray_o, ray_d): - ''' + """ ray_o, ray_d: [..., 3] compute the depth of the intersection point between this ray and unit sphere - ''' + """ # note: d1 becomes negative if this mid point is behind camera d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) p = ray_o + d1.unsqueeze(-1) * ray_d # consider the case where the ray does not intersect the sphere - ray_d_cos = 1. / torch.norm(ray_d, dim=-1) + ray_d_cos = 1.0 / torch.norm(ray_d, dim=-1) p_norm_sq = torch.sum(p * p, dim=-1) - if (p_norm_sq >= 1.).any(): - raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!') - d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos + if (p_norm_sq >= 1.0).any(): + raise Exception( + "Not all your cameras are bounded by the unit sphere; " + "please make sure the cameras are normalized properly!" 
+        )
+    d2 = torch.sqrt(1.0 - p_norm_sq) * ray_d_cos
     return d1 + d2
 
 
 def perturb_samples(z_vals):
     # get intervals between samples
-    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+    mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
     upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
     lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
     # uniform samples in those intervals
@@ -70,53 +77,76 @@ def perturb_samples(z_vals):
 
 
 def sample_pdf(bins, weights, N_samples, det=False):
-    '''
+    """
     :param bins: tensor of shape [..., M+1], M is the number of bins
     :param weights: tensor of shape [..., M]
     :param N_samples: number of samples along each ray
     :param det: if True, will perform deterministic sampling
     :return: [..., N_samples]
-    '''
+    """
     # Get pdf
-    weights = weights + TINY_NUMBER  # prevent nans
-    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
-    cdf = torch.cumsum(pdf, dim=-1)  # [..., M]
-    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)  # [..., M+1]
+    weights = weights + TINY_NUMBER  # prevent nans
+    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
+    cdf = torch.cumsum(pdf, dim=-1)  # [..., M]
+    cdf = torch.cat(
+        [torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1
+    )  # [..., M+1]
 
     # Take uniform samples
     dots_sh = list(weights.shape[:-1])
     M = weights.shape[-1]
 
     min_cdf = 0.00
-    max_cdf = 1.00  # prevent outlier samples
+    max_cdf = 1.00  # prevent outlier samples
 
     if det:
         u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
-        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])  # [..., N_samples]
+        u = u.view([1] * len(dots_sh) + [N_samples]).expand(
+            dots_sh
+            + [
+                N_samples,
+            ]
+        )  # [..., N_samples]
     else:
         sh = dots_sh + [N_samples]
-        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf  # [..., N_samples]
+        u = (
+            torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf
+        )  # [..., N_samples]
 
     # Invert CDF
     # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
-    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
+    above_inds = torch.sum(
+        u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1
+    ).long()
 
     # random sample inside each bin
-    below_inds = torch.clamp(above_inds-1, min=0)
-    inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N_samples, 2]
-
-    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
-    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)  # [..., N_samples, 2]
-
-    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])  # [..., N_samples, M+1]
-    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]
+    below_inds = torch.clamp(above_inds - 1, min=0)
+    inds_g = torch.stack(
+        (below_inds, above_inds), dim=-1
+    )  # [..., N_samples, 2]
+
+    cdf = cdf.unsqueeze(-2).expand(
+        dots_sh + [N_samples, M + 1]
+    )  # [..., N_samples, M+1]
+    cdf_g = torch.gather(
+        input=cdf, dim=-1, index=inds_g
+    )  # [..., N_samples, 2]
+
+    bins = bins.unsqueeze(-2).expand(
+        dots_sh + [N_samples, M + 1]
+    )  # [..., N_samples, M+1]
+    bins_g = torch.gather(
+        input=bins, dim=-1, index=inds_g
+    )  # [..., N_samples, 2]
 
     # fix numeric issue
-    denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N_samples]
-    denom = torch.where(denom<TINY_NUMBER, torch.ones_like(denom), denom)
-        rgb_im = img_HWC2CHW(log_data[m]['rgb'])
-        rgb_im = torch.clamp(rgb_im, min=0., max=1.)  # just in case diffuse+specular>1
-        writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step)
-
-        rgb_im = img_HWC2CHW(log_data[m]['fg_rgb'])
-        rgb_im = torch.clamp(rgb_im, min=0., max=1.)
# just in case diffuse+specular>1 - writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step) - depth = log_data[m]['fg_depth'] - depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step) - - rgb_im = img_HWC2CHW(log_data[m]['bg_rgb']) - rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1 - writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step) - depth = log_data[m]['bg_depth'] - depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step) - bg_lambda = log_data[m]['bg_lambda'] - bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step) + rgb_im = img_HWC2CHW(log_data[m]["rgb"]) + rgb_im = torch.clamp( + rgb_im, min=0.0, max=1.0 + ) # just in case diffuse+specular>1 + writer.add_image( + prefix + "level_{}/rgb".format(m), rgb_im, global_step + ) + + rgb_im = img_HWC2CHW(log_data[m]["fg_rgb"]) + rgb_im = torch.clamp( + rgb_im, min=0.0, max=1.0 + ) # just in case diffuse+specular>1 + writer.add_image( + prefix + "level_{}/fg_rgb".format(m), rgb_im, global_step + ) + depth = log_data[m]["fg_depth"] + depth_im = img_HWC2CHW( + colorize(depth, cmap_name="jet", append_cbar=True, mask=mask) + ) + writer.add_image( + prefix + "level_{}/fg_depth".format(m), depth_im, global_step + ) + + rgb_im = img_HWC2CHW(log_data[m]["bg_rgb"]) + rgb_im = torch.clamp( + rgb_im, min=0.0, max=1.0 + ) # just in case diffuse+specular>1 + writer.add_image( + prefix + "level_{}/bg_rgb".format(m), rgb_im, global_step + ) + depth = log_data[m]["bg_depth"] + depth_im = img_HWC2CHW( + colorize(depth, cmap_name="jet", append_cbar=True, mask=mask) + ) + writer.add_image( + prefix + "level_{}/bg_depth".format(m), depth_im, global_step + ) + bg_lambda = log_data[m]["bg_lambda"] + bg_lambda_im = img_HWC2CHW( + colorize(bg_lambda, cmap_name="hot", append_cbar=True, mask=mask) + ) + writer.add_image( + prefix + "level_{}/bg_lambda".format(m), bg_lambda_im, global_step + ) def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' + os.environ["MASTER_ADDR"] = "localhost" # port = np.random.randint(12355, 12399) # os.environ['MASTER_PORT'] = '{}'.format(port) - os.environ['MASTER_PORT'] = '12355' + os.environ["MASTER_PORT"] = "12355" # initialize the process group - torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + torch.distributed.init_process_group( + "gloo", rank=rank, world_size=world_size + ) def cleanup(): @@ -291,21 +405,30 @@ def create_nerf(rank, args): torch.cuda.set_device(rank) models = OrderedDict() - models['cascade_level'] = args.cascade_level - models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')] - for m in range(models['cascade_level']): + models["cascade_level"] = args.cascade_level + models["cascade_samples"] = [ + int(x.strip()) for x in args.cascade_samples.split(",") + ] + for m in range(models["cascade_level"]): img_names = None if args.optim_autoexpo: # load training image names for autoexposure - f = os.path.join(args.basedir, args.expname, 'train_images.json') + f = os.path.join(args.basedir, args.expname, "train_images.json") with open(f) as file: img_names = json.load(file) - net = NerfNetWithAutoExpo(args, 
optim_autoexpo=args.optim_autoexpo, img_names=img_names).to(rank) - net = DDP(net, device_ids=[rank], output_device=rank, find_unused_parameters=True) + net = NerfNetWithAutoExpo( + args, optim_autoexpo=args.optim_autoexpo, img_names=img_names + ).to(rank) + net = DDP( + net, + device_ids=[rank], + output_device=rank, + find_unused_parameters=True, + ) # net = DDP(net, device_ids=[rank], output_device=rank) optim = torch.optim.Adam(net.parameters(), lr=args.lrate) - models['net_{}'.format(m)] = net - models['optim_{}'.format(m)] = optim + models["net_{}".format(m)] = net + models["optim_{}".format(m)] = optim start = -1 @@ -313,23 +436,30 @@ def create_nerf(rank, args): if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)): ckpts = [args.ckpt_path] else: - ckpts = [os.path.join(args.basedir, args.expname, f) - for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')] + ckpts = [ + os.path.join(args.basedir, args.expname, f) + for f in sorted( + os.listdir(os.path.join(args.basedir, args.expname)) + ) + if f.endswith(".pth") + ] + def path2iter(path): tmp = os.path.basename(path)[:-4] - idx = tmp.rfind('_') - return int(tmp[idx + 1:]) + idx = tmp.rfind("_") + return int(tmp[idx + 1 :]) + ckpts = sorted(ckpts, key=path2iter) - logger.info('Found ckpts: {}'.format(ckpts)) + logger.info("Found ckpts: {}".format(ckpts)) if len(ckpts) > 0 and not args.no_reload: fpath = ckpts[-1] - logger.info('Reloading from: {}'.format(fpath)) + logger.info("Reloading from: {}".format(fpath)) start = path2iter(fpath) # configure map_location properly for different processes - map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} + map_location = {"cuda:%d" % 0: "cuda:%d" % rank} to_load = torch.load(fpath, map_location=map_location) - for m in range(models['cascade_level']): - for name in ['net_{}'.format(m), 'optim_{}'.format(m)]: + for m in range(models["cascade_level"]): + for name in ["net_{}".format(m), "optim_{}".format(m)]: models[name].load_state_dict(to_load[name]) return start, models @@ -343,40 +473,55 @@ def ddp_train_nerf(rank, args): setup_logger() ###### decide chunk size according to gpu memory - logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory)) + logger.info( + "gpu_mem: {}".format( + torch.cuda.get_device_properties(rank).total_memory + ) + ) if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14: - logger.info('setting batch size according to 24G gpu') + logger.info("setting batch size according to 24G gpu") args.N_rand = 1024 args.chunk_size = 8192 else: - logger.info('setting batch size according to 12G gpu') + logger.info("setting batch size according to 12G gpu") args.N_rand = 512 args.chunk_size = 4096 ###### Create log dir and copy the config file if rank == 0: os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True) - f = os.path.join(args.basedir, args.expname, 'args.txt') - with open(f, 'w') as file: + f = os.path.join(args.basedir, args.expname, "args.txt") + with open(f, "w") as file: for arg in sorted(vars(args)): attr = getattr(args, arg) - file.write('{} = {}\n'.format(arg, attr)) + file.write("{} = {}\n".format(arg, attr)) if args.config is not None: - f = os.path.join(args.basedir, args.expname, 'config.txt') - with open(f, 'w') as file: - file.write(open(args.config, 'r').read()) + f = os.path.join(args.basedir, args.expname, "config.txt") + with open(f, "w") as file: + file.write(open(args.config, "r").read()) torch.distributed.barrier() - ray_samplers = 
load_data_split(args.datadir, args.scene, split='train', - try_load_min_depth=args.load_min_depth) - val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation', - try_load_min_depth=args.load_min_depth, skip=args.testskip) + ray_samplers = load_data_split( + args.datadir, + args.scene, + split="train", + try_load_min_depth=args.load_min_depth, + ) + val_ray_samplers = load_data_split( + args.datadir, + args.scene, + split="validation", + try_load_min_depth=args.load_min_depth, + skip=args.testskip, + ) # write training image names for autoexposure if args.optim_autoexpo: - f = os.path.join(args.basedir, args.expname, 'train_images.json') - with open(f, 'w') as file: - img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))] + f = os.path.join(args.basedir, args.expname, "train_images.json") + with open(f, "w") as file: + img_names = [ + ray_samplers[i].img_path for i in range(len(ray_samplers)) + ] json.dump(img_names, file, indent=2) ###### create network and wrap in ddp; each process should do this @@ -390,79 +535,138 @@ def ddp_train_nerf(rank, args): ##### only main process should do the logging if rank == 0: - writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname)) + writer = SummaryWriter( + os.path.join(args.basedir, "summaries", args.expname) + ) # start training - what_val_to_log = 0 # helper variable for parallel rendering of a image + what_val_to_log = 0 # helper variable for parallel rendering of a image what_train_to_log = 0 - for global_step in range(start+1, start+1+args.N_iters): + for global_step in tqdm(range(start + 1, start + 1 + args.N_iters)): time0 = time.time() scalars_to_log = OrderedDict() ### Start of core optimization loop - scalars_to_log['resolution'] = ray_samplers[0].resolution_level + scalars_to_log["resolution"] = ray_samplers[0].resolution_level # randomly sample rays and move to device i = np.random.randint(low=0, high=len(ray_samplers)) - ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False) + ray_batch = ray_samplers[i].random_sample( + args.N_rand, center_crop=False + ) for key in ray_batch: if torch.is_tensor(ray_batch[key]): ray_batch[key] = ray_batch[key].to(rank) # forward and backward - dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays - all_rets = [] # results on different cascade levels - for m in range(models['cascade_level']): - optim = models['optim_{}'.format(m)] - net = models['net_{}'.format(m)] + dots_sh = list(ray_batch["ray_d"].shape[:-1]) # number of rays + all_rets = [] # results on different cascade levels + for m in range(models["cascade_level"]): + optim = models["optim_{}".format(m)] + net = models["net_{}".format(m)] # sample depths - N_samples = models['cascade_samples'][m] + N_samples = models["cascade_samples"][m] if m == 0: # foreground depth - fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,] - fg_near_depth = ray_batch['min_depth'] # [..., ] + fg_far_depth = intersect_sphere( + ray_batch["ray_o"], ray_batch["ray_d"] + ) # [...,] + fg_near_depth = ray_batch["min_depth"] # [..., ] step = (fg_far_depth - fg_near_depth) / (N_samples - 1) - fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples] - fg_depth = perturb_samples(fg_depth) # random perturbation during training + fg_depth = torch.stack( + [fg_near_depth + i * step for i in range(N_samples)], + dim=-1, + ) # [..., N_samples] + fg_depth = perturb_samples( + fg_depth + ) # random perturbation during 
training # background depth - bg_depth = torch.linspace(0., 1., N_samples).view( - [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank) - bg_depth = perturb_samples(bg_depth) # random perturbation during training + bg_depth = ( + torch.linspace(0.0, 1.0, N_samples) + .view( + [ + 1, + ] + * len(dots_sh) + + [ + N_samples, + ] + ) + .expand( + dots_sh + + [ + N_samples, + ] + ) + .to(rank) + ) + bg_depth = perturb_samples( + bg_depth + ) # random perturbation during training else: # sample pdf and concat with earlier samples - fg_weights = ret['fg_weights'].clone().detach() - fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1] - fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] - fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights, - N_samples=N_samples, det=False) # [..., N_samples] - fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1)) + fg_weights = ret["fg_weights"].clone().detach() + fg_depth_mid = 0.5 * ( + fg_depth[..., 1:] + fg_depth[..., :-1] + ) # [..., N_samples-1] + fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] + fg_depth_samples = sample_pdf( + bins=fg_depth_mid, + weights=fg_weights, + N_samples=N_samples, + det=False, + ) # [..., N_samples] + fg_depth, _ = torch.sort( + torch.cat((fg_depth, fg_depth_samples), dim=-1) + ) # sample pdf and concat with earlier samples - bg_weights = ret['bg_weights'].clone().detach() - bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) - bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] - bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights, - N_samples=N_samples, det=False) # [..., N_samples] - bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1)) + bg_weights = ret["bg_weights"].clone().detach() + bg_depth_mid = 0.5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) + bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] + bg_depth_samples = sample_pdf( + bins=bg_depth_mid, + weights=bg_weights, + N_samples=N_samples, + det=False, + ) # [..., N_samples] + bg_depth, _ = torch.sort( + torch.cat((bg_depth, bg_depth_samples), dim=-1) + ) optim.zero_grad() - ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch['img_name']) + ret = net( + ray_batch["ray_o"], + ray_batch["ray_d"], + fg_far_depth, + fg_depth, + bg_depth, + img_name=ray_batch["img_name"], + ) all_rets.append(ret) - rgb_gt = ray_batch['rgb'].to(rank) - if 'autoexpo' in ret: - scale, shift = ret['autoexpo'] - scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item() - scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item() + rgb_gt = ray_batch["rgb"].to(rank) + if "autoexpo" in ret: + scale, shift = ret["autoexpo"] + scalars_to_log[ + "level_{}/autoexpo_scale".format(m) + ] = scale.item() + scalars_to_log[ + "level_{}/autoexpo_shift".format(m) + ] = shift.item() # rgb_gt = scale * rgb_gt + shift - rgb_pred = (ret['rgb'] - shift) / scale + rgb_pred = (ret["rgb"] - shift) / scale rgb_loss = img2mse(rgb_pred, rgb_gt) - loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale-1.)+torch.abs(shift)) + loss = rgb_loss + args.lambda_autoexpo * ( + torch.abs(scale - 1.0) + torch.abs(shift) + ) else: - rgb_loss = img2mse(ret['rgb'], rgb_gt) + rgb_loss = img2mse(ret["rgb"], rgb_gt) loss = rgb_loss - scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item() - scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(rgb_loss.item()) + scalars_to_log["level_{}/loss".format(m)] = 
rgb_loss.item() + scalars_to_log["level_{}/pnsr".format(m)] = mse2psnr( + rgb_loss.item() + ) loss.backward() optim.step() @@ -471,49 +675,85 @@ def ddp_train_nerf(rank, args): ### end of core optimization loop dt = time.time() - time0 - scalars_to_log['iter_time'] = dt + scalars_to_log["iter_time"] = dt ### only main process should do the logging if rank == 0 and (global_step % args.i_print == 0 or global_step < 10): - logstr = '{} step: {} '.format(args.expname, global_step) + logstr = "{} step: {} ".format(args.expname, global_step) for k in scalars_to_log: - logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k]) + logstr += " {}: {:.6f}".format(k, scalars_to_log[k]) writer.add_scalar(k, scalars_to_log[k], global_step) logger.info(logstr) ### each process should do this; but only main process merges the results - if global_step % args.i_img == 0 or global_step == start+1: + if global_step % args.i_img == 0 or global_step == start + 1: #### critical: make sure each process is working on the same random image time0 = time.time() idx = what_val_to_log % len(val_ray_samplers) - log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size) + log_data = render_single_image( + rank, + args.world_size, + models, + val_ray_samplers[idx], + args.chunk_size, + ) what_val_to_log += 1 dt = time.time() - time0 - if rank == 0: # only main process should do this - logger.info('Logged a random validation view in {} seconds'.format(dt)) - log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix='val/') + if rank == 0: # only main process should do this + logger.info( + "Logged a random validation view in {} seconds".format(dt) + ) + log_view_to_tb( + writer, + global_step, + log_data, + gt_img=val_ray_samplers[idx].get_img(), + mask=None, + prefix="val/", + ) time0 = time.time() idx = what_train_to_log % len(ray_samplers) - log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size) + log_data = render_single_image( + rank, + args.world_size, + models, + ray_samplers[idx], + args.chunk_size, + ) what_train_to_log += 1 dt = time.time() - time0 - if rank == 0: # only main process should do this - logger.info('Logged a random training view in {} seconds'.format(dt)) - log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/') + if rank == 0: # only main process should do this + logger.info( + "Logged a random training view in {} seconds".format(dt) + ) + log_view_to_tb( + writer, + global_step, + log_data, + gt_img=ray_samplers[idx].get_img(), + mask=None, + prefix="train/", + ) del log_data torch.cuda.empty_cache() - if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0): + if rank == 0 and ( + global_step % args.i_weights == 0 and global_step > 0 + ): # saving checkpoints and logging - fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step)) + fpath = os.path.join( + args.basedir, + args.expname, + "model_{:06d}.pth".format(global_step), + ) to_save = OrderedDict() - for m in range(models['cascade_level']): - name = 'net_{}'.format(m) + for m in range(models["cascade_level"]): + name = "net_{}".format(m) to_save[name] = models[name].state_dict() - name = 'optim_{}'.format(m) + name = "optim_{}".format(m) to_save[name] = models[name].state_dict() torch.save(to_save, fpath) @@ -523,64 +763,194 @@ def ddp_train_nerf(rank, args): def config_parser(): import configargparse + parser = 
configargparse.ArgumentParser() - parser.add_argument('--config', is_config_file=True, help='config file path') - parser.add_argument("--expname", type=str, help='experiment name') - parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs') + parser.add_argument( + "--config", + is_config_file=True, + help="config file path", + ) + parser.add_argument( + "--expname", + type=str, + help="experiment name", + ) + parser.add_argument( + "--basedir", + type=str, + default="./logs/", + help="where to store ckpts and logs", + ) # dataset options - parser.add_argument("--datadir", type=str, default=None, help='input data directory') - parser.add_argument("--scene", type=str, default=None, help='scene name') - parser.add_argument("--testskip", type=int, default=8, - help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') + parser.add_argument( + "--datadir", + type=str, + default=None, + help="input data directory", + ) + parser.add_argument( + "--scene", + type=str, + default=None, + help="scene name", + ) + parser.add_argument( + "--testskip", + type=int, + default=8, + help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", + ) # model size - parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network') - parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network') - parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D') + parser.add_argument( + "--netdepth", + type=int, + default=8, + help="layers in coarse network", + ) + parser.add_argument( + "--netwidth", + type=int, + default=256, + help="channels per layer in coarse network", + ) + parser.add_argument( + "--use_viewdirs", + action="store_true", + help="use full 5D input instead of 3D", + ) # checkpoints - parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') - parser.add_argument("--ckpt_path", type=str, default=None, - help='specific weights npy file to reload for coarse network') + parser.add_argument( + "--no_reload", + action="store_true", + help="do not reload weights from saved ckpt", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="specific weights npy file to reload for coarse network", + ) # batch size - parser.add_argument("--N_rand", type=int, default=32 * 32 * 2, - help='batch size (number of random rays per gradient step)') - parser.add_argument("--chunk_size", type=int, default=1024 * 8, - help='number of rays processed in parallel, decrease if running out of memory') + parser.add_argument( + "--N_rand", + type=int, + default=32 * 32 * 2, + help="batch size (number of random rays per gradient step)", + ) + parser.add_argument( + "--chunk_size", + type=int, + default=1024 * 8, + help="number of rays processed in parallel, decrease if running out of memory", + ) # iterations - parser.add_argument("--N_iters", type=int, default=250001, - help='number of iterations') + parser.add_argument( + "--N_iters", + type=int, + default=250001, + help="number of iterations", + ) # render only - parser.add_argument("--render_splits", type=str, default='test', - help='splits to render') + parser.add_argument( + "--render_splits", + type=str, + default="test", + help="splits to render", + ) # cascade training - parser.add_argument("--cascade_level", type=int, default=2, - help='number of cascade levels') - parser.add_argument("--cascade_samples", 
type=str, default='64,64', - help='samples at each level') + parser.add_argument( + "--cascade_level", + type=int, + default=2, + help="number of cascade levels", + ) + parser.add_argument( + "--cascade_samples", + type=str, + default="64,64", + help="samples at each level", + ) # multiprocess learning - parser.add_argument("--world_size", type=int, default='-1', - help='number of processes') + parser.add_argument( + "--world_size", + type=int, + default="-1", + help="number of processes (GPU). defaults to -1 for every GPU.", + ) # optimize autoexposure - parser.add_argument("--optim_autoexpo", action='store_true', - help='optimize autoexposure parameters') - parser.add_argument("--lambda_autoexpo", type=float, default=1., help='regularization weight for autoexposure') + parser.add_argument( + "--optim_autoexpo", + action="store_true", + help="optimize autoexposure parameters", + ) + parser.add_argument( + "--lambda_autoexpo", + type=float, + default=1.0, + help="regularization weight for autoexposure", + ) # learning rate options - parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') - parser.add_argument("--lrate_decay_factor", type=float, default=0.1, - help='decay learning rate by a factor every specified number of steps') - parser.add_argument("--lrate_decay_steps", type=int, default=5000, - help='decay learning rate by a factor every specified number of steps') + parser.add_argument( + "--lrate", + type=float, + default=5e-4, + help="learning rate", + ) + parser.add_argument( + "--lrate_decay_factor", + type=float, + default=0.1, + help="decay learning rate by a factor every specified number of steps", + ) + parser.add_argument( + "--lrate_decay_steps", + type=int, + default=5000, + help="decay learning rate by a factor every specified number of steps", + ) # rendering options - parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples') - parser.add_argument("--max_freq_log2", type=int, default=10, - help='log2 of max freq for positional encoding (3D location)') - parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4, - help='log2 of max freq for positional encoding (2D direction)') - parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth') + parser.add_argument( + "--det", + action="store_true", + help="deterministic sampling for coarse and fine samples", + ) + parser.add_argument( + "--max_freq_log2", + type=int, + default=10, + help="log2 of max freq for positional encoding (3D location)", + ) + parser.add_argument( + "--max_freq_log2_viewdirs", + type=int, + default=4, + help="log2 of max freq for positional encoding (2D direction)", + ) + parser.add_argument( + "--load_min_depth", + action="store_true", + help="whether to load min depth", + ) # logging/saving options - parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin') - parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') - parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') + parser.add_argument( + "--i_print", + type=int, + default=100, + help="frequency of console printout and metric loggin", + ) + parser.add_argument( + "--i_img", + type=int, + default=500, + help="frequency of tensorboard image logging", + ) + parser.add_argument( + "--i_weights", + type=int, + default=10000, + help="frequency of weight ckpt saving", + ) return parser @@ 
-592,15 +962,12 @@ def train(): if args.world_size == -1: args.world_size = torch.cuda.device_count() - logger.info('Using # gpus: {}'.format(args.world_size)) - torch.multiprocessing.spawn(ddp_train_nerf, - args=(args,), - nprocs=args.world_size, - join=True) + logger.info("Using # gpus: {}".format(args.world_size)) + torch.multiprocessing.spawn( + ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True + ) -if __name__ == '__main__': +if __name__ == "__main__": setup_logger() train() - -
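Note on the tqdm change: aside from the isort/black reformatting, the one behavioural addition visible in this diff is wrapping the training loop's step range in tqdm. The snippet below is a minimal, self-contained sketch of that pattern and is not part of the patch; start and N_iters are stand-in values mirroring create_nerf()'s start index and the --N_iters option.

import time

from tqdm import tqdm

start = -1     # mirrors create_nerf(), which returns -1 when no checkpoint is reloaded
N_iters = 200  # stand-in for args.N_iters

# Same pattern as the patched ddp_train_nerf.py: tqdm wraps the range so a
# live progress bar advances once per global training step.
for global_step in tqdm(range(start + 1, start + 1 + N_iters)):
    time.sleep(0.01)  # placeholder for one optimization step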