import torch
import torch.nn as nn
import torch.optim
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import os
from collections import OrderedDict
from ddp_model import NerfNet
import time
# from data_loader import load_data
# from nerf_sample_ray import RaySamplerSingleImage
from data_loader_split import load_data_split
import numpy as np
from tensorboardX import SummaryWriter
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
import logging

# Module-level logger. setup_logger() fetches this same logger (same name) and
# attaches the console handler to it.
logger = logging.getLogger(__package__)


def setup_logger():
    '''
    Configure the package logger for console output.

    Sets the logger named ``__package__`` to DEBUG and attaches a DEBUG-level
    StreamHandler with a timestamped formatter. Each call attaches one more
    handler, so calling this twice in the same process duplicates every line.
    '''
    # create logger
    logger = logging.getLogger(__package__)
    logger.setLevel(logging.DEBUG)

    # create console handler and set level to debug
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)

    # create formatter
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # add formatter to ch
    ch.setFormatter(formatter)

    # add ch to logger
    logger.addHandler(ch)


def intersect_sphere(ray_o, ray_d):
    '''
    ray_o, ray_d: [..., 3]
    compute the depth of the intersection point between this ray and unit sphere

    Returns a ray parameter t of shape [...] such that ray_o + t * ray_d lies on
    the unit sphere centred at the origin. Because d2 >= 0, this is the farther
    of the two intersections.

    NOTE(review): despite the comment in the body, rays that miss the sphere are
    not handled. When |p| > 1 the sqrt argument is negative and the result is NaN.
    '''
    # note: d1 becomes negative if this mid point is behind camera
    # d1 is the ray parameter of the point on the ray closest to the origin
    d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
    p = ray_o + d1.unsqueeze(-1) * ray_d   # that closest point, [..., 3]
    # consider the case where the ray does not intersect the sphere
    ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
    # half-chord length sqrt(1 - |p|^2), converted to ray-parameter units by 1/|ray_d|
    d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos

    return d1 + d2


def perturb_samples(z_vals):
    '''
    Stratified jitter of sample depths along the last dimension.

    Each sample is replaced by a uniform draw from the interval bounded by the
    midpoints to its neighbours. The first and last intervals are clipped at the
    first and last original samples. The output has the same shape as z_vals and
    is drawn from torch's global RNG.

    Intended for depths that increase along the last dim, which is how the
    callers in this file build them.
    '''
    # get intervals between samples
    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
    lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
    # uniform samples in those intervals
    t_rand = torch.rand_like(z_vals)
    z_vals = lower + (upper - lower) * t_rand   # [N_rays, N_samples]

    return z_vals


def sample_pdf(bins, weights, N_samples, det=False):
    '''
    :param bins: tensor of shape [..., M+1], M is the number of bins
    :param weights: tensor of shape [..., M]
    :param N_samples: number of samples along each ray
    :param det: if True, will perform deterministic sampling
    :return: [..., N_samples]

    Inverse-transform sampling of the piecewise-constant pdf given by `weights`
    over the intervals of `bins`. It proceeds in four steps:
      1. build the CDF;
      2. draw u in [0, 1] (evenly spaced if det, otherwise uniform random);
      3. for each u, locate the pair of CDF knots bracketing it (inds_g / cdf_g);
      4. gather the matching bin edges (bins_g).
    The remainder of this function was lost; see the NOTE(review) in the body.
    '''
    # Get pdf
    weights = weights + TINY_NUMBER      # prevent nans
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)    # [..., M]
    cdf = torch.cumsum(pdf, dim=-1)                             # [..., M]
    cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1)     # [..., M+1]

    # Take uniform samples
    dots_sh = list(weights.shape[:-1])
    M = weights.shape[-1]

    min_cdf = 0.00
    max_cdf = 1.00       # prevent outlier samples

    if det:
        u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
        u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,])   # [..., N_samples]
    else:
        sh = dots_sh + [N_samples]
        u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf        # [..., N_samples]

    # Invert CDF
    # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
    # above_inds = number of knots cdf[..., :M] that are <= u (cdf is non-decreasing)
    above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()

    # random sample inside each bin
    below_inds = torch.clamp(above_inds-1, min=0)
    inds_g = torch.stack((below_inds, above_inds), dim=-1)     # [..., N_samples, 2]

    cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])   # [..., N_samples, M+1]
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)       # [..., N_samples, 2]

    bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1])    # [..., N_samples, M+1]
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)  # [..., N_samples, 2]

    # fix numeric issue
    denom = cdf_g[..., 1] - cdf_g[..., 0]      # [..., N_samples]
    # NOTE(review): THIS FILE IS DAMAGED HERE. The text between "torch.where(denom"
    # and "1 writer.add_image(...)" is missing. It looks as if everything from a "<"
    # (plausibly "denom<TINY_NUMBER ...") to a later ">" (plausibly the
    # "diffuse+specular>1" comment in log_view_to_tb) was stripped as if it were
    # markup. Lost with it:
    #   - the rest of sample_pdf,
    #   - the whole of render_single_image (called from ddp_train_nerf),
    #   - the head of log_view_to_tb (called from ddp_train_nerf); the indented
    #     lines below look like the end of its per-level loop.
    # The next statement is kept byte-for-byte and does not parse. Restore this
    # region from version control.
    denom = torch.where(denom1 writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step)
        rgb_im = img_HWC2CHW(log_data[m]['fg_rgb'])
        rgb_im = torch.clamp(rgb_im, min=0., max=1.)  # just in case diffuse+specular>1
        writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step)
        depth = log_data[m]['fg_depth']
        depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, mask=mask))
        writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step)

        rgb_im = img_HWC2CHW(log_data[m]['bg_rgb'])
        rgb_im = torch.clamp(rgb_im, min=0., max=1.)  # just in case diffuse+specular>1
        writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step)
        depth = log_data[m]['bg_depth']
        depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, mask=mask))
        writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step)
        bg_lambda = log_data[m]['bg_lambda']
        bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True, mask=mask))
        writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step)


def setup(rank, world_size):
    '''
    Join the torch.distributed process group for this rank.

    Uses the gloo backend with a fixed rendezvous at localhost:12355.
    NOTE(review): the hard-coded port means two runs on the same host conflict.
    '''
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    '''Tear down the process group created by setup().'''
    torch.distributed.destroy_process_group()


def ddp_train_nerf(rank, args):
    '''
    Per-process training loop. train() launches one of these per rank via
    torch.multiprocessing.spawn.

    :param rank: process index in [0, args.world_size). Also used as this
        process's CUDA device (torch.cuda.set_device(rank) and .to(rank)).
    :param args: namespace produced by config_parser().parse_args()

    Outline:
      1. Join the process group and configure logging.
      2. Rank 0 writes args.txt and config.txt into basedir/expname. All ranks
         then meet at a barrier, so the directory exists before it is listed.
      3. Every rank loads the 'train' and 'validation' ray samplers.
      4. Build args.cascade_level NerfNet models, each wrapped in DDP with its
         own Adam optimiser. Resume from the newest *.pth checkpoint if present.
      5. Run args.N_iters optimisation steps. Rank 0 logs scalars and images
         and saves checkpoints.

    NOTE(review): this function never reads args.lrate_decay_factor or
    args.lrate_decay_steps, so the learning rate stays at args.lrate. It also
    ignores args.det and args.N_iters_perturb: training always perturbs samples,
    and det=False and center_crop=False are hard-coded.
    '''
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### Create log dir and copy the config file
    if rank == 0:
        os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
        f = os.path.join(args.basedir, args.expname, 'args.txt')
        with open(f, 'w') as file:
            for arg in sorted(vars(args)):
                attr = getattr(args, arg)
                file.write('{} = {}\n'.format(arg, attr))
        if args.config is not None:
            f = os.path.join(args.basedir, args.expname, 'config.txt')
            with open(f, 'w') as file:
                # NOTE(review): the handle opened on args.config is never closed
                # explicitly; a nested `with open(args.config, 'r')` would fix that.
                file.write(open(args.config, 'r').read())
    # Every rank waits here, so the experiment directory created by rank 0 exists
    # before the checkpoint search below lists it.
    torch.distributed.barrier()

    ###### load data and create ray samplers; each process should do this
    # data = load_data(args.datadir, args.scene, args.testskip)
    # ray_samplers = []
    # for i in data['i_train']:
    #     ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i],
    #                                               img=data['images'][i],
    #                                               img_path=data['paths'][i],
    #                                               mask=data['masks'][i],
    #                                               min_depth=data['min_depths'][i]))
    #
    # val_ray_samplers = []
    # for i in data['i_val']:
    #     val_ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i],
    #                                                   img=data['images'][i],
    #                                                   img_path=data['paths'][i],
    #                                                   mask=data['masks'][i],
    #                                                   min_depth=data['min_depths'][i]))
    # # free memory
    # del data
    ray_samplers = load_data_split(args.datadir, args.scene, split='train')
    val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation')

    ###### create network and wrap in ddp; each process should do this
    # fix random seed just to make sure the network is initialized with same weights at different processes
    torch.manual_seed(777)
    # very important!!! otherwise it might introduce extra memory in rank=0 gpu
    torch.cuda.set_device(rank)

    # `models` bundles the cascade configuration with one (DDP net, Adam optimiser)
    # pair per level, under keys 'net_<m>' / 'optim_<m>'.
    models = OrderedDict()
    models['cascade_level'] = args.cascade_level
    models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
    for m in range(models['cascade_level']):
        net = NerfNet(args).to(rank)
        net = DDP(net, device_ids=[rank], output_device=rank)
        optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
        models['net_{}'.format(m)] = net
        models['optim_{}'.format(m)] = optim

    start = -1   # step of the loaded checkpoint; training resumes at start + 1

    ###### load pretrained weights; each process should do this
    if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
        ckpts = [args.ckpt_path]
    else:
        ckpts = [os.path.join(args.basedir, args.expname, f)
                 for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]

    def path2iter(path):
        # 'model_001234.pth' -> 1234. Drops the 4-character extension, then parses
        # the text after the last '_'. Raises ValueError if that text is not an int.
        tmp = os.path.basename(path)[:-4]
        idx = tmp.rfind('_')
        return int(tmp[idx + 1:])

    ckpts = sorted(ckpts, key=path2iter)
    logger.info('Found ckpts: {}'.format(ckpts))
    if len(ckpts) > 0 and not args.no_reload:
        fpath = ckpts[-1]   # highest iteration number
        logger.info('Reloading from: {}'.format(fpath))
        start = path2iter(fpath)
        # configure map_location properly for different processes
        # Checkpoints are written by rank 0 (device cuda:0), so tensors stored on
        # cuda:0 are remapped onto this rank's device.
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        to_load = torch.load(fpath, map_location=map_location)
        for m in range(models['cascade_level']):
            for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
                models[name].load_state_dict(to_load[name])
                # NOTE(review): redundant second load of the same state dict
                models[name].load_state_dict(to_load[name])

    ##### important!!!
    # make sure different processes sample different rays
    np.random.seed((rank + 1) * 777)
    # make sure different processes have different perturbations in depth samples
    torch.manual_seed((rank + 1) * 777)

    ##### only main process should do the logging
    # `writer` exists only on rank 0; every later use is guarded by `rank == 0`.
    # NOTE(review): the writer is never closed/flushed before the function returns.
    if rank == 0:
        writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))

    # start training
    what_val_to_log = 0             # helper variable for parallel rendering of a image
    what_train_to_log = 0
    for global_step in range(start+1, start+1+args.N_iters):

        time0 = time.time()
        scalars_to_log = OrderedDict()
        ### Start of core optimization loop
        scalars_to_log['resolution'] = ray_samplers[0].resolution_level
        # randomly sample rays and move to device
        i = np.random.randint(low=0, high=len(ray_samplers))
        ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
        for key in ray_batch:
            if torch.is_tensor(ray_batch[key]):
                ray_batch[key] = ray_batch[key].to(rank)

        # forward and backward
        dots_sh = list(ray_batch['ray_d'].shape[:-1])  # number of rays
        all_rets = []                                  # results on different cascade levels
        for m in range(models['cascade_level']):
            # Each cascade level has its own network, optimiser and loss; levels are
            # stepped one after another within the same iteration.
            optim = models['optim_{}'.format(m)]
            net = models['net_{}'.format(m)]

            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth
                # far bound: the ray parameter where each ray leaves the unit sphere
                fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d'])  # [...,]
                # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
                # NOTE(review): the "[..., 3]" annotation looks wrong. min_depth is
                # combined elementwise with fg_far_depth ([...,]) and the stack below
                # is annotated [..., N_samples]. TODO confirm.
                fg_near_depth = ray_batch['min_depth']  # [..., 3]
                # N_samples evenly spaced depths from near to far bound, both inclusive
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
                fg_depth = perturb_samples(fg_depth)   # random perturbation during training

                # background depth
                # N_samples values evenly spaced in [0, 1], broadcast to every ray
                bg_depth = torch.linspace(0., 1., N_samples).view(
                            [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
                bg_depth = perturb_samples(bg_depth)   # random perturbation during training
            else:
                # Hierarchical sampling. `ret` still holds the previous level's output,
                # and fg_far_depth is reused from level 0. The previous level's
                # per-sample weights are detached, so no gradient flows back into that
                # level. They define a pdf over the midpoints of the current samples;
                # new samples are drawn from it and merged (sorted) with the old ones.
                # NOTE(review): assumes ret['*_weights'] has one entry per depth in
                # fg_depth / bg_depth. TODO confirm against NerfNet.
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])    # [..., N_samples-1]
                # drop the two end weights so that len(weights) == len(bins) - 1
                fg_weights = fg_weights[..., 1:-1]                               # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=False)    # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]                              # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=False)    # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

            optim.zero_grad()
            ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth)
            all_rets.append(ret)

            rgb_gt = ray_batch['rgb'].to(rank)
            loss = img2mse(ret['rgb'], rgb_gt)
            # The logged loss/PSNR cover only the composite-RGB term, before the
            # diffuse term is added below.
            # NOTE(review): the tag is spelled 'pnsr' (sic); renaming it would split
            # existing TensorBoard curves.
            scalars_to_log['level_{}/loss'.format(m)] = loss.item()
            scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
            # regularize sigma with photo-consistency
            loss = loss + img2mse(ret['diffuse_rgb'], rgb_gt)
            loss.backward()
            optim.step()

            # # clean unused memory
            # torch.cuda.empty_cache()

        ### end of core optimization loop
        dt = time.time() - time0
        scalars_to_log['iter_time'] = dt

        ### only main process should do the logging
        if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
            logstr = '{} step: {} '.format(args.expname, global_step)
            for k in scalars_to_log:
                logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                writer.add_scalar(k, scalars_to_log[k], global_step)
            logger.info(logstr)

        ### each process should do this; but only main process merges the results
        # NOTE(review): render_single_image is not defined in the intact part of this
        # file (see the damaged region in sample_pdf). The comment above implies it
        # is collective across ranks, so every rank must reach these calls.
        if global_step % args.i_img == 0 or global_step == start+1:
            #### critical: make sure each process is working on the same random image
            # (the view index is a deterministic counter, identical on all ranks)
            time0 = time.time()
            idx = what_val_to_log % len(val_ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
            what_val_to_log += 1
            dt = time.time() - time0
            if rank == 0:   # only main process should do this
                logger.info('Logged a random validation view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].img_orig, mask=None, prefix='val/')

            time0 = time.time()
            idx = what_train_to_log % len(ray_samplers)
            log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
            what_train_to_log += 1
            dt = time.time() - time0
            if rank == 0:   # only main process should do this
                logger.info('Logged a random training view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].img_orig, mask=None, prefix='train/')

            # drop the rendered images before releasing cached GPU memory
            log_data = None
            torch.cuda.empty_cache()

        if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
            # saving checkpoints and logging
            # The file name encodes the step; path2iter() parses it back on resume.
            fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
            to_save = OrderedDict()
            for m in range(models['cascade_level']):
                name = 'net_{}'.format(m)
                to_save[name] = models[name].state_dict()

                name = 'optim_{}'.format(m)
                to_save[name] = models[name].state_dict()
            torch.save(to_save, fpath)

    # clean up for multi-processing
    cleanup()


def config_parser():
    '''
    Build the command-line / config-file parser for this script.

    Because of ``is_config_file=True`` on --config, options can also be supplied
    through the file passed with --config.

    NOTE(review): many options are not read anywhere in the intact part of this
    file, for example --testskip, --devices, --bg_devices, --opt_level,
    --near_depth, --far_depth, --lrate_decay_*, --det, --N_iters_perturb and the
    render/mesh options. They may be consumed by NerfNet(args), or they may be
    leftovers. TODO confirm.
    '''
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
    # dataset options
    parser.add_argument("--datadir", type=str, default=None, help='input data directory')
    parser.add_argument("--scene", type=str, default=None, help='scene name')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
    # model size
    parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
    # checkpoints
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
    # NOTE(review): despite the help text, ddp_train_nerf loads this with torch.load
    # and expects it to contain every cascade level ('net_<m>' / 'optim_<m>').
    parser.add_argument("--ckpt_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    # batch size
    parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--chunk_size", type=int, default=1024 * 8,
                        help='number of rays processed in parallel, decrease if running out of memory')
    # iterations
    parser.add_argument("--N_iters", type=int, default=250001,
                        help='number of iterations')
    # cascade training
    parser.add_argument("--cascade_level", type=int, default=2,
                        help='number of cascade levels')
    # comma-separated, one integer per cascade level
    parser.add_argument("--cascade_samples", type=str, default='64,64',
                        help='samples at each level')
    parser.add_argument("--devices", type=str, default='0,1',
                        help='cuda device for each level')
    parser.add_argument("--bg_devices", type=str, default='0,2',
                        help='cuda device for the background of each level')
    # The default is the string '-1'. argparse runs string defaults through `type`,
    # so args.world_size ends up as the int -1, which train() replaces with the
    # number of visible GPUs.
    parser.add_argument("--world_size", type=int, default='-1',
                        help='number of processes')
    # mixed precison training
    parser.add_argument("--opt_level", type=str, default='O1',
                        help='mixed precison training')
    parser.add_argument("--near_depth", type=float, default=0.1,
                        help='near depth plane')
    parser.add_argument("--far_depth", type=float, default=50.,
                        help='far depth plane')
    # learning rate options
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
    parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
                        help='decay learning rate by a factor every specified number of steps')
    parser.add_argument("--lrate_decay_steps", type=int, default=5000,
                        help='decay learning rate by a factor every specified number of steps')
    # rendering options
    parser.add_argument("--inv_uniform", action='store_true',
                        help='if True, will uniformly sample inverse depths')
    parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
    parser.add_argument("--max_freq_log2", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--N_iters_perturb", type=int, default=1000,
                        help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck')
    parser.add_argument("--raw_noise_std", type=float, default=1.,
                        help='std dev of noise added to regularize sigma output, 1e0 recommended')
    parser.add_argument("--white_bkgd", action='store_true',
                        help='apply the trick to avoid fitting to white background')
    # no training; render only
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_train", action='store_true', help='render the training set')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    # no training; extract mesh only
    parser.add_argument("--mesh_only", action='store_true',
                        help='do not optimize, extract mesh from pretrained model')
    parser.add_argument("--N_pts", type=int, default=256,
                        help='voxel resolution; N_pts * N_pts * N_pts')
    parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50',
                        help='threshold(s) for mesh extraction; can use multiple thresholds')
    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')

    return parser


def train():
    '''
    Script entry point.

    Parses the configuration and defaults --world_size to the number of visible
    CUDA devices. It then spawns args.world_size processes; spawn calls
    ddp_train_nerf(rank, args) once for each rank.
    '''
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))
    torch.multiprocessing.spawn(ddp_train_nerf,
                                args=(args,),
                                nprocs=args.world_size,
                                join=True)


if __name__ == '__main__':
    setup_logger()
    train()