Add multi-gpu support

master
yenchenlin 3 years ago
parent b02703bec5
commit 74a16a54bc

@ -188,20 +188,22 @@ def create_nerf(args):
skips = [4]
model = NeRF(D=args.netdepth, W=args.netwidth,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
model = nn.DataParallel(model).to(device)
grad_vars = list(model.parameters())
model_fine = None
if args.N_importance > 0:
model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
input_ch=input_ch, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
model_fine = nn.DataParallel(model_fine).to(device)
grad_vars += list(model_fine.parameters())
network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
embed_fn=embed_fn,
embeddirs_fn=embeddirs_fn,
netchunk=args.netchunk)
netchunk=args.netchunk_per_gpu*args.n_gpus)
# Create optimizer
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
@ -448,7 +450,7 @@ def config_parser():
help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*32,
help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*64,
parser.add_argument("--netchunk_per_gpu", type=int, default=1024*64*4,
help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true',
help='only take random rays from 1 image at a time')
@ -536,8 +538,11 @@ def train():
parser = config_parser()
args = parser.parse_args()
# Load data
# Multi-GPU
args.n_gpus = torch.cuda.device_count()
print(f"Using {args.n_gpus} GPU(s).")
# Load data
if args.dataset_type == 'llff':
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
recenter=True, bd_factor=.75,
@ -733,7 +738,7 @@ def train():
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
batch_rays = torch.stack([rays_o, rays_d], 0) # (2, N_rand, 3)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
##### Core optimization loop #####

Loading…
Cancel
Save