diff --git a/run_nerf.py b/run_nerf.py
index fab80da..d9e0c8d 100644
--- a/run_nerf.py
+++ b/run_nerf.py
@@ -188,20 +188,22 @@ def create_nerf(args):
     skips = [4]
     model = NeRF(D=args.netdepth, W=args.netwidth,
                  input_ch=input_ch, output_ch=output_ch, skips=skips,
-                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
+                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
+    model = nn.DataParallel(model).to(device)
     grad_vars = list(model.parameters())
 
     model_fine = None
     if args.N_importance > 0:
         model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                           input_ch=input_ch, output_ch=output_ch, skips=skips,
-                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
+                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
+        model_fine = nn.DataParallel(model_fine).to(device)
         grad_vars += list(model_fine.parameters())
 
     network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                 embed_fn=embed_fn,
                                                                 embeddirs_fn=embeddirs_fn,
-                                                                netchunk=args.netchunk)
+                                                                netchunk=args.netchunk_per_gpu*args.n_gpus)
 
     # Create optimizer
     optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
@@ -448,7 +450,7 @@ def config_parser():
                         help='exponential learning rate decay (in 1000 steps)')
     parser.add_argument("--chunk", type=int, default=1024*32,
                         help='number of rays processed in parallel, decrease if running out of memory')
-    parser.add_argument("--netchunk", type=int, default=1024*64,
+    parser.add_argument("--netchunk_per_gpu", type=int, default=1024*64*4,
                         help='number of pts sent through network in parallel, decrease if running out of memory')
     parser.add_argument("--no_batching", action='store_true',
                         help='only take random rays from 1 image at a time')
@@ -536,8 +538,11 @@ def train():
     parser = config_parser()
     args = parser.parse_args()
 
-    # Load data
+    # Multi-GPU
+    args.n_gpus = torch.cuda.device_count()
+    print(f"Using {args.n_gpus} GPU(s).")
 
+    # Load data
     if args.dataset_type == 'llff':
         images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                   recenter=True, bd_factor=.75,
@@ -733,7 +738,7 @@ def train():
                 select_coords = coords[select_inds].long()  # (N_rand, 2)
                 rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                 rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
-                batch_rays = torch.stack([rays_o, rays_d], 0)
+                batch_rays = torch.stack([rays_o, rays_d], 0)  # (2, N_rand, 3)
                 target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
 
         ##### Core optimization loop #####
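
Note (reviewer sketch, not part of the patch): the diff wraps both MLPs in nn.DataParallel, which scatters dimension 0 of each input batch across all visible GPUs, so the effective chunk fed through the network scales as netchunk_per_gpu * n_gpus while per-GPU memory stays roughly constant. Below is a minimal, self-contained illustration of that pattern under stated assumptions: TinyMLP is a hypothetical stand-in for the NeRF MLP, and the chunked loop mirrors what batchify/run_network do in this file.

    # Sketch only: illustrates DataParallel batch-splitting with per-GPU chunk
    # sizing; TinyMLP is a hypothetical placeholder, not the repo's NeRF class.
    import torch
    import torch.nn as nn

    class TinyMLP(nn.Module):
        def __init__(self, in_ch=3, hidden=64, out_ch=4):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_ch, hidden), nn.ReLU(),
                nn.Linear(hidden, out_ch))

        def forward(self, x):
            return self.net(x)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_gpus = max(torch.cuda.device_count(), 1)

    model = TinyMLP()
    if torch.cuda.is_available():
        # DataParallel replicates the module and splits dim 0 of the input
        # across GPUs, so each replica sees ~netchunk_per_gpu points.
        model = nn.DataParallel(model)
    model = model.to(device)

    netchunk_per_gpu = 1024 * 64
    netchunk = netchunk_per_gpu * n_gpus  # total points per forward pass

    pts = torch.rand(netchunk * 2 + 100, 3, device=device)  # arbitrary count

    # Chunked evaluation, analogous to batchify: each forward pass feeds
    # `netchunk` points, which DataParallel scatters across the GPUs.
    with torch.no_grad():
        out = torch.cat([model(pts[i:i + netchunk])
                         for i in range(0, pts.shape[0], netchunk)], dim=0)
    print(out.shape)  # (netchunk * 2 + 100, 4)

This is also why the default moves from --netchunk 1024*64 to --netchunk_per_gpu 1024*64*4: the flag now sizes the chunk per replica rather than per process, and create_nerf multiplies it back up by args.n_gpus.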