Add multi-gpu support
This commit is contained in:
parent
b02703bec5
commit
74a16a54bc
1 changed files with 11 additions and 6 deletions
17
run_nerf.py
17
run_nerf.py
|
@ -188,20 +188,22 @@ def create_nerf(args):
|
|||
skips = [4]
|
||||
model = NeRF(D=args.netdepth, W=args.netwidth,
|
||||
input_ch=input_ch, output_ch=output_ch, skips=skips,
|
||||
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
|
||||
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
|
||||
model = nn.DataParallel(model).to(device)
|
||||
grad_vars = list(model.parameters())
|
||||
|
||||
model_fine = None
|
||||
if args.N_importance > 0:
|
||||
model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
|
||||
input_ch=input_ch, output_ch=output_ch, skips=skips,
|
||||
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
|
||||
input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
|
||||
model_fine = nn.DataParallel(model_fine).to(device)
|
||||
grad_vars += list(model_fine.parameters())
|
||||
|
||||
network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
|
||||
embed_fn=embed_fn,
|
||||
embeddirs_fn=embeddirs_fn,
|
||||
netchunk=args.netchunk)
|
||||
netchunk=args.netchunk_per_gpu*args.n_gpus)
|
||||
|
||||
# Create optimizer
|
||||
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
|
||||
|
@ -448,7 +450,7 @@ def config_parser():
|
|||
help='exponential learning rate decay (in 1000 steps)')
|
||||
parser.add_argument("--chunk", type=int, default=1024*32,
|
||||
help='number of rays processed in parallel, decrease if running out of memory')
|
||||
parser.add_argument("--netchunk", type=int, default=1024*64,
|
||||
parser.add_argument("--netchunk_per_gpu", type=int, default=1024*64*4,
|
||||
help='number of pts sent through network in parallel, decrease if running out of memory')
|
||||
parser.add_argument("--no_batching", action='store_true',
|
||||
help='only take random rays from 1 image at a time')
|
||||
|
@ -536,8 +538,11 @@ def train():
|
|||
parser = config_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load data
|
||||
# Multi-GPU
|
||||
args.n_gpus = torch.cuda.device_count()
|
||||
print(f"Using {args.n_gpus} GPU(s).")
|
||||
|
||||
# Load data
|
||||
if args.dataset_type == 'llff':
|
||||
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
|
||||
recenter=True, bd_factor=.75,
|
||||
|
@ -733,7 +738,7 @@ def train():
|
|||
select_coords = coords[select_inds].long() # (N_rand, 2)
|
||||
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
|
||||
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
|
||||
batch_rays = torch.stack([rays_o, rays_d], 0)
|
||||
batch_rays = torch.stack([rays_o, rays_d], 0) # (2, N_rand, 3)
|
||||
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
|
||||
|
||||
##### Core optimization loop #####
|
||||
|
|
Loading…
Reference in a new issue