import json import logging import os import time from collections import OrderedDict import numpy as np import torch import torch.distributed import torch.multiprocessing import torch.nn as nn import torch.optim from tensorboardX import SummaryWriter from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm from data_loader_split import load_data_split from ddp_model import NerfNetWithAutoExpo from utils import TINY_NUMBER, colorize, img2mse, img_HWC2CHW, mse2psnr logger = logging.getLogger(__package__) def setup_logger(): # create logger logger = logging.getLogger(__package__) # logger.setLevel(logging.DEBUG) logger.setLevel(logging.INFO) # create console handler and set level to debug ch = logging.StreamHandler() ch.setLevel(logging.DEBUG) # create formatter formatter = logging.Formatter( "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) # add formatter to ch ch.setFormatter(formatter) # add ch to logger logger.addHandler(ch) def intersect_sphere(ray_o, ray_d): """ ray_o, ray_d: [..., 3] compute the depth of the intersection point between this ray and unit sphere """ # note: d1 becomes negative if this mid point is behind camera d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) p = ray_o + d1.unsqueeze(-1) * ray_d # consider the case where the ray does not intersect the sphere ray_d_cos = 1.0 / torch.norm(ray_d, dim=-1) p_norm_sq = torch.sum(p * p, dim=-1) if (p_norm_sq >= 1.0).any(): raise Exception( "Not all your cameras are bounded by the unit sphere; " "please make sure the cameras are normalized properly!" ) d2 = torch.sqrt(1.0 - p_norm_sq) * ray_d_cos return d1 + d2 def perturb_samples(z_vals): # get intervals between samples mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) upper = torch.cat([mids, z_vals[..., -1:]], dim=-1) lower = torch.cat([z_vals[..., 0:1], mids], dim=-1) # uniform samples in those intervals t_rand = torch.rand_like(z_vals) z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples] return z_vals def sample_pdf(bins, weights, N_samples, det=False): """ :param bins: tensor of shape [..., M+1], M is the number of bins :param weights: tensor of shape [..., M] :param N_samples: number of samples along each ray :param det: if True, will perform deterministic sampling :return: [..., N_samples] """ # Get pdf weights = weights + TINY_NUMBER # prevent nans pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M] cdf = torch.cumsum(pdf, dim=-1) # [..., M] cdf = torch.cat( [torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1 ) # [..., M+1] # Take uniform samples dots_sh = list(weights.shape[:-1]) M = weights.shape[-1] min_cdf = 0.00 max_cdf = 1.00 # prevent outlier samples if det: u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device) u = u.view([1] * len(dots_sh) + [N_samples]).expand( dots_sh + [ N_samples, ] ) # [..., N_samples] else: sh = dots_sh + [N_samples] u = ( torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf ) # [..., N_samples] # Invert CDF # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,] above_inds = torch.sum( u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1 ).long() # random sample inside each bin below_inds = torch.clamp(above_inds - 1, min=0) inds_g = torch.stack( (below_inds, above_inds), dim=-1 ) # [..., N_samples, 2] cdf = cdf.unsqueeze(-2).expand( dots_sh + [N_samples, M + 1] ) # [..., N_samples, M+1] cdf_g = torch.gather( input=cdf, dim=-1, index=inds_g ) # [..., N_samples, 2] bins = bins.unsqueeze(-2).expand( dots_sh + [N_samples, M + 1] ) # [..., N_samples, M+1] bins_g = torch.gather( input=bins, dim=-1, index=inds_g ) # [..., N_samples, 2] # fix numeric issue denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples] denom = torch.where(denom < TINY_NUMBER, torch.ones_like(denom), denom) t = (u - cdf_g[..., 0]) / denom samples = bins_g[..., 0] + t * ( bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER ) return samples def render_single_image(rank, world_size, models, ray_sampler, chunk_size): ##### parallel rendering of a single image ray_batch = ray_sampler.get_all() if (ray_batch["ray_d"].shape[0] // world_size) * world_size != ray_batch[ "ray_d" ].shape[0]: raise Exception( "Number of pixels in the image is not divisible by the number of GPUs!\n\t# pixels: {}\n\t# GPUs: {}".format( ray_batch["ray_d"].shape[0], world_size ) ) # split into ranks; make sure different processes don't overlap rank_split_sizes = [ ray_batch["ray_d"].shape[0] // world_size, ] * world_size rank_split_sizes[-1] = ray_batch["ray_d"].shape[0] - sum( rank_split_sizes[:-1] ) for key in ray_batch: if torch.is_tensor(ray_batch[key]): ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[ rank ].to(rank) # split into chunks and render inside each process ray_batch_split = OrderedDict() for key in ray_batch: if torch.is_tensor(ray_batch[key]): ray_batch_split[key] = torch.split(ray_batch[key], chunk_size) # forward and backward ret_merge_chunk = [OrderedDict() for _ in range(models["cascade_level"])] for s in range(len(ray_batch_split["ray_d"])): ray_o = ray_batch_split["ray_o"][s] ray_d = ray_batch_split["ray_d"][s] min_depth = ray_batch_split["min_depth"][s] dots_sh = list(ray_d.shape[:-1]) for m in range(models["cascade_level"]): net = models["net_{}".format(m)] # sample depths N_samples = models["cascade_samples"][m] if m == 0: # foreground depth fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,] fg_near_depth = min_depth # [..., ] step = (fg_far_depth - fg_near_depth) / (N_samples - 1) fg_depth = torch.stack( [fg_near_depth + i * step for i in range(N_samples)], dim=-1, ) # [..., N_samples] # background depth bg_depth = ( torch.linspace(0.0, 1.0, N_samples) .view( [ 1, ] * len(dots_sh) + [ N_samples, ] ) .expand( dots_sh + [ N_samples, ] ) .to(rank) ) # delete unused memory del fg_near_depth del step torch.cuda.empty_cache() else: # sample pdf and concat with earlier samples fg_weights = ret["fg_weights"].clone().detach() fg_depth_mid = 0.5 * ( fg_depth[..., 1:] + fg_depth[..., :-1] ) # [..., N_samples-1] fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] fg_depth_samples = sample_pdf( bins=fg_depth_mid, weights=fg_weights, N_samples=N_samples, det=True, ) # [..., N_samples] fg_depth, _ = torch.sort( torch.cat((fg_depth, fg_depth_samples), dim=-1) ) # sample pdf and concat with earlier samples bg_weights = ret["bg_weights"].clone().detach() bg_depth_mid = 0.5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] bg_depth_samples = sample_pdf( bins=bg_depth_mid, weights=bg_weights, N_samples=N_samples, det=True, ) # [..., N_samples] bg_depth, _ = torch.sort( torch.cat((bg_depth, bg_depth_samples), dim=-1) ) # delete unused memory del fg_weights del fg_depth_mid del fg_depth_samples del bg_weights del bg_depth_mid del bg_depth_samples torch.cuda.empty_cache() with torch.no_grad(): ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth) for key in ret: if key not in ["fg_weights", "bg_weights"]: if torch.is_tensor(ret[key]): if key not in ret_merge_chunk[m]: ret_merge_chunk[m][key] = [ ret[key].cpu(), ] else: ret_merge_chunk[m][key].append(ret[key].cpu()) ret[key] = None # clean unused memory torch.cuda.empty_cache() # merge results from different chunks for m in range(len(ret_merge_chunk)): for key in ret_merge_chunk[m]: ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0) # merge results from different processes if rank == 0: ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))] for m in range(len(ret_merge_chunk)): for key in ret_merge_chunk[m]: # generate tensors to store results from other processes sh = list(ret_merge_chunk[m][key].shape[1:]) ret_merge_rank[m][key] = [ torch.zeros( *[ size, ] + sh, dtype=torch.float32 ) for size in rank_split_sizes ] torch.distributed.gather( ret_merge_chunk[m][key], ret_merge_rank[m][key] ) ret_merge_rank[m][key] = ( torch.cat(ret_merge_rank[m][key], dim=0) .reshape((ray_sampler.H, ray_sampler.W, -1)) .squeeze() ) # print(m, key, ret_merge_rank[m][key].shape) else: # send results to main process for m in range(len(ret_merge_chunk)): for key in ret_merge_chunk[m]: torch.distributed.gather(ret_merge_chunk[m][key]) # only rank 0 program returns if rank == 0: return ret_merge_rank else: return None def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=""): rgb_im = img_HWC2CHW(torch.from_numpy(gt_img)) writer.add_image(prefix + "rgb_gt", rgb_im, global_step) for m in range(len(log_data)): rgb_im = img_HWC2CHW(log_data[m]["rgb"]) rgb_im = torch.clamp( rgb_im, min=0.0, max=1.0 ) # just in case diffuse+specular>1 writer.add_image( prefix + "level_{}/rgb".format(m), rgb_im, global_step ) rgb_im = img_HWC2CHW(log_data[m]["fg_rgb"]) rgb_im = torch.clamp( rgb_im, min=0.0, max=1.0 ) # just in case diffuse+specular>1 writer.add_image( prefix + "level_{}/fg_rgb".format(m), rgb_im, global_step ) depth = log_data[m]["fg_depth"] depth_im = img_HWC2CHW( colorize(depth, cmap_name="jet", append_cbar=True, mask=mask) ) writer.add_image( prefix + "level_{}/fg_depth".format(m), depth_im, global_step ) rgb_im = img_HWC2CHW(log_data[m]["bg_rgb"]) rgb_im = torch.clamp( rgb_im, min=0.0, max=1.0 ) # just in case diffuse+specular>1 writer.add_image( prefix + "level_{}/bg_rgb".format(m), rgb_im, global_step ) depth = log_data[m]["bg_depth"] depth_im = img_HWC2CHW( colorize(depth, cmap_name="jet", append_cbar=True, mask=mask) ) writer.add_image( prefix + "level_{}/bg_depth".format(m), depth_im, global_step ) bg_lambda = log_data[m]["bg_lambda"] bg_lambda_im = img_HWC2CHW( colorize(bg_lambda, cmap_name="hot", append_cbar=True, mask=mask) ) writer.add_image( prefix + "level_{}/bg_lambda".format(m), bg_lambda_im, global_step ) def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" # port = np.random.randint(12355, 12399) # os.environ['MASTER_PORT'] = '{}'.format(port) os.environ["MASTER_PORT"] = "12355" # initialize the process group torch.distributed.init_process_group( "gloo", rank=rank, world_size=world_size ) def cleanup(): torch.distributed.destroy_process_group() def create_nerf(rank, args): ###### create network and wrap in ddp; each process should do this # fix random seed just to make sure the network is initialized with same weights at different processes torch.manual_seed(777) # very important!!! otherwise it might introduce extra memory in rank=0 gpu torch.cuda.set_device(rank) models = OrderedDict() models["cascade_level"] = args.cascade_level models["cascade_samples"] = [ int(x.strip()) for x in args.cascade_samples.split(",") ] for m in range(models["cascade_level"]): img_names = None if args.optim_autoexpo: # load training image names for autoexposure f = os.path.join(args.basedir, args.expname, "train_images.json") with open(f) as file: img_names = json.load(file) net = NerfNetWithAutoExpo( args, optim_autoexpo=args.optim_autoexpo, img_names=img_names ).to(rank) net = DDP( net, device_ids=[rank], output_device=rank, find_unused_parameters=True, ) # net = DDP(net, device_ids=[rank], output_device=rank) optim = torch.optim.Adam(net.parameters(), lr=args.lrate) models["net_{}".format(m)] = net models["optim_{}".format(m)] = optim start = -1 ###### load pretrained weights; each process should do this if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)): ckpts = [args.ckpt_path] else: ckpts = [ os.path.join(args.basedir, args.expname, f) for f in sorted( os.listdir(os.path.join(args.basedir, args.expname)) ) if f.endswith(".pth") ] def path2iter(path): tmp = os.path.basename(path)[:-4] idx = tmp.rfind("_") return int(tmp[idx + 1 :]) ckpts = sorted(ckpts, key=path2iter) logger.info("Found ckpts: {}".format(ckpts)) if len(ckpts) > 0 and not args.no_reload: fpath = ckpts[-1] logger.info("Reloading from: {}".format(fpath)) start = path2iter(fpath) # configure map_location properly for different processes map_location = {"cuda:%d" % 0: "cuda:%d" % rank} to_load = torch.load(fpath, map_location=map_location) for m in range(models["cascade_level"]): for name in ["net_{}".format(m), "optim_{}".format(m)]: models[name].load_state_dict(to_load[name]) return start, models def ddp_train_nerf(rank, args): ###### set up multi-processing setup(rank, args.world_size) ###### set up logger logger = logging.getLogger(__package__) setup_logger() ###### decide chunk size according to gpu memory logger.info( "gpu_mem: {}".format( torch.cuda.get_device_properties(rank).total_memory ) ) if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14: logger.info("setting batch size according to 24G gpu") args.N_rand = 1024 args.chunk_size = 8192 else: logger.info("setting batch size according to 12G gpu") args.N_rand = 512 args.chunk_size = 4096 ###### Create log dir and copy the config file if rank == 0: os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True) f = os.path.join(args.basedir, args.expname, "args.txt") with open(f, "w") as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write("{} = {}\n".format(arg, attr)) if args.config is not None: f = os.path.join(args.basedir, args.expname, "config.txt") with open(f, "w") as file: file.write(open(args.config, "r").read()) torch.distributed.barrier() ray_samplers = load_data_split( args.datadir, args.scene, split="train", try_load_min_depth=args.load_min_depth, ) val_ray_samplers = load_data_split( args.datadir, args.scene, split="validation", try_load_min_depth=args.load_min_depth, skip=args.testskip, ) # write training image names for autoexposure if args.optim_autoexpo: f = os.path.join(args.basedir, args.expname, "train_images.json") with open(f, "w") as file: img_names = [ ray_samplers[i].img_path for i in range(len(ray_samplers)) ] json.dump(img_names, file, indent=2) ###### create network and wrap in ddp; each process should do this start, models = create_nerf(rank, args) ##### important!!! # make sure different processes sample different rays np.random.seed((rank + 1) * 777) # make sure different processes have different perturbations in depth samples torch.manual_seed((rank + 1) * 777) ##### only main process should do the logging if rank == 0: writer = SummaryWriter( os.path.join(args.basedir, "summaries", args.expname) ) # start training what_val_to_log = 0 # helper variable for parallel rendering of a image what_train_to_log = 0 for global_step in tqdm(range(start + 1, start + 1 + args.N_iters)): time0 = time.time() scalars_to_log = OrderedDict() ### Start of core optimization loop scalars_to_log["resolution"] = ray_samplers[0].resolution_level # randomly sample rays and move to device i = np.random.randint(low=0, high=len(ray_samplers)) ray_batch = ray_samplers[i].random_sample( args.N_rand, center_crop=False ) for key in ray_batch: if torch.is_tensor(ray_batch[key]): ray_batch[key] = ray_batch[key].to(rank) # forward and backward dots_sh = list(ray_batch["ray_d"].shape[:-1]) # number of rays all_rets = [] # results on different cascade levels for m in range(models["cascade_level"]): optim = models["optim_{}".format(m)] net = models["net_{}".format(m)] # sample depths N_samples = models["cascade_samples"][m] if m == 0: # foreground depth fg_far_depth = intersect_sphere( ray_batch["ray_o"], ray_batch["ray_d"] ) # [...,] fg_near_depth = ray_batch["min_depth"] # [..., ] step = (fg_far_depth - fg_near_depth) / (N_samples - 1) fg_depth = torch.stack( [fg_near_depth + i * step for i in range(N_samples)], dim=-1, ) # [..., N_samples] fg_depth = perturb_samples( fg_depth ) # random perturbation during training # background depth bg_depth = ( torch.linspace(0.0, 1.0, N_samples) .view( [ 1, ] * len(dots_sh) + [ N_samples, ] ) .expand( dots_sh + [ N_samples, ] ) .to(rank) ) bg_depth = perturb_samples( bg_depth ) # random perturbation during training else: # sample pdf and concat with earlier samples fg_weights = ret["fg_weights"].clone().detach() fg_depth_mid = 0.5 * ( fg_depth[..., 1:] + fg_depth[..., :-1] ) # [..., N_samples-1] fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] fg_depth_samples = sample_pdf( bins=fg_depth_mid, weights=fg_weights, N_samples=N_samples, det=False, ) # [..., N_samples] fg_depth, _ = torch.sort( torch.cat((fg_depth, fg_depth_samples), dim=-1) ) # sample pdf and concat with earlier samples bg_weights = ret["bg_weights"].clone().detach() bg_depth_mid = 0.5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] bg_depth_samples = sample_pdf( bins=bg_depth_mid, weights=bg_weights, N_samples=N_samples, det=False, ) # [..., N_samples] bg_depth, _ = torch.sort( torch.cat((bg_depth, bg_depth_samples), dim=-1) ) optim.zero_grad() ret = net( ray_batch["ray_o"], ray_batch["ray_d"], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch["img_name"], ) all_rets.append(ret) rgb_gt = ray_batch["rgb"].to(rank) if "autoexpo" in ret: scale, shift = ret["autoexpo"] scalars_to_log[ "level_{}/autoexpo_scale".format(m) ] = scale.item() scalars_to_log[ "level_{}/autoexpo_shift".format(m) ] = shift.item() # rgb_gt = scale * rgb_gt + shift rgb_pred = (ret["rgb"] - shift) / scale rgb_loss = img2mse(rgb_pred, rgb_gt) loss = rgb_loss + args.lambda_autoexpo * ( torch.abs(scale - 1.0) + torch.abs(shift) ) else: rgb_loss = img2mse(ret["rgb"], rgb_gt) loss = rgb_loss scalars_to_log["level_{}/loss".format(m)] = rgb_loss.item() scalars_to_log["level_{}/pnsr".format(m)] = mse2psnr( rgb_loss.item() ) loss.backward() optim.step() # # clean unused memory # torch.cuda.empty_cache() ### end of core optimization loop dt = time.time() - time0 scalars_to_log["iter_time"] = dt ### only main process should do the logging if rank == 0 and (global_step % args.i_print == 0 or global_step < 10): logstr = "{} step: {} ".format(args.expname, global_step) for k in scalars_to_log: logstr += " {}: {:.6f}".format(k, scalars_to_log[k]) writer.add_scalar(k, scalars_to_log[k], global_step) logger.info(logstr) ### each process should do this; but only main process merges the results if global_step % args.i_img == 0 or global_step == start + 1: #### critical: make sure each process is working on the same random image time0 = time.time() idx = what_val_to_log % len(val_ray_samplers) log_data = render_single_image( rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size, ) what_val_to_log += 1 dt = time.time() - time0 if rank == 0: # only main process should do this logger.info( "Logged a random validation view in {} seconds".format(dt) ) log_view_to_tb( writer, global_step, log_data, gt_img=val_ray_samplers[idx].get_img(), mask=None, prefix="val/", ) time0 = time.time() idx = what_train_to_log % len(ray_samplers) log_data = render_single_image( rank, args.world_size, models, ray_samplers[idx], args.chunk_size, ) what_train_to_log += 1 dt = time.time() - time0 if rank == 0: # only main process should do this logger.info( "Logged a random training view in {} seconds".format(dt) ) log_view_to_tb( writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix="train/", ) del log_data torch.cuda.empty_cache() if rank == 0 and ( global_step % args.i_weights == 0 and global_step > 0 ): # saving checkpoints and logging fpath = os.path.join( args.basedir, args.expname, "model_{:06d}.pth".format(global_step), ) to_save = OrderedDict() for m in range(models["cascade_level"]): name = "net_{}".format(m) to_save[name] = models[name].state_dict() name = "optim_{}".format(m) to_save[name] = models[name].state_dict() torch.save(to_save, fpath) # clean up for multi-processing cleanup() def config_parser(): import configargparse parser = configargparse.ArgumentParser() parser.add_argument( "--config", is_config_file=True, help="config file path", ) parser.add_argument( "--expname", type=str, help="experiment name", ) parser.add_argument( "--basedir", type=str, default="./logs/", help="where to store ckpts and logs", ) # dataset options parser.add_argument( "--datadir", type=str, default=None, help="input data directory", ) parser.add_argument( "--scene", type=str, default=None, help="scene name", ) parser.add_argument( "--testskip", type=int, default=8, help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", ) # model size parser.add_argument( "--netdepth", type=int, default=8, help="layers in coarse network", ) parser.add_argument( "--netwidth", type=int, default=256, help="channels per layer in coarse network", ) parser.add_argument( "--use_viewdirs", action="store_true", help="use full 5D input instead of 3D", ) # checkpoints parser.add_argument( "--no_reload", action="store_true", help="do not reload weights from saved ckpt", ) parser.add_argument( "--ckpt_path", type=str, default=None, help="specific weights npy file to reload for coarse network", ) # batch size parser.add_argument( "--N_rand", type=int, default=32 * 32 * 2, help="batch size (number of random rays per gradient step)", ) parser.add_argument( "--chunk_size", type=int, default=1024 * 8, help="number of rays processed in parallel, decrease if running out of memory", ) # iterations parser.add_argument( "--N_iters", type=int, default=250001, help="number of iterations", ) # render only parser.add_argument( "--render_splits", type=str, default="test", help="splits to render", ) # cascade training parser.add_argument( "--cascade_level", type=int, default=2, help="number of cascade levels", ) parser.add_argument( "--cascade_samples", type=str, default="64,64", help="samples at each level", ) # multiprocess learning parser.add_argument( "--world_size", type=int, default="-1", help="number of processes (GPU). defaults to -1 for every GPU.", ) # optimize autoexposure parser.add_argument( "--optim_autoexpo", action="store_true", help="optimize autoexposure parameters", ) parser.add_argument( "--lambda_autoexpo", type=float, default=1.0, help="regularization weight for autoexposure", ) # learning rate options parser.add_argument( "--lrate", type=float, default=5e-4, help="learning rate", ) parser.add_argument( "--lrate_decay_factor", type=float, default=0.1, help="decay learning rate by a factor every specified number of steps", ) parser.add_argument( "--lrate_decay_steps", type=int, default=5000, help="decay learning rate by a factor every specified number of steps", ) # rendering options parser.add_argument( "--det", action="store_true", help="deterministic sampling for coarse and fine samples", ) parser.add_argument( "--max_freq_log2", type=int, default=10, help="log2 of max freq for positional encoding (3D location)", ) parser.add_argument( "--max_freq_log2_viewdirs", type=int, default=4, help="log2 of max freq for positional encoding (2D direction)", ) parser.add_argument( "--load_min_depth", action="store_true", help="whether to load min depth", ) # logging/saving options parser.add_argument( "--i_print", type=int, default=100, help="frequency of console printout and metric loggin", ) parser.add_argument( "--i_img", type=int, default=500, help="frequency of tensorboard image logging", ) parser.add_argument( "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving", ) return parser def train(): parser = config_parser() args = parser.parse_args() logger.info(parser.format_values()) if args.world_size == -1: args.world_size = torch.cuda.device_count() logger.info("Using # gpus: {}".format(args.world_size)) torch.multiprocessing.spawn( ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True ) if __name__ == "__main__": setup_logger() train()