clean code

master
Kai-46 4 years ago
parent 18ff4489b3
commit abbb5d136d

.gitignore

@ -1,6 +1,9 @@
# scripts
*.sh
# mac
.DS_Store
# pycharm
.idea/
@ -141,4 +144,4 @@ dmypy.json
.pytype/
# Cython debug symbols
cython_debug/
cython_debug/

@ -0,0 +1,23 @@
# NeRF++
Codebase for the paper "NeRF++: Analyzing and Improving Neural Radiance Fields":
* Works with 360° captures of large-scale unbounded scenes.
* Supports multi-GPU training and inference.
## Data
* Download our preprocessed data from [tanks_and_temples](), [lf_data]().
* Put the data in the code directory.
* Data format (see the parsing sketch below):
  * Each scene consists of 3 splits: train/test/validation.
  * Intrinsics and poses are stored as flattened 4x4 matrices (row-major).
  * The OpenCV camera coordinate system is adopted, i.e., x--->right, y--->down, z--->into the scene.
* Scene normalization: move the average camera center to the origin, and put all camera centers inside the unit sphere.
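
For illustration, a minimal sketch of how an intrinsics or pose file can be read; it mirrors the `parse_txt` helper in `data_loader_split.py` shown in the diff below (the example path is hypothetical):

```python
import numpy as np

def parse_txt(filename):
    # each file stores 16 whitespace-separated floats: a flattened 4x4 matrix
    nums = open(filename).read().split()
    return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)

# hypothetical example:
# K = parse_txt('data/tanks_and_temples/tat_training_Truck/train/intrinsics/000000.txt')
```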
## Training
```bash
python ddp_train_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt
```
## Testing
```bash
python ddp_test_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt --render_splits test,camera_path
```
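
Both scripts spawn one worker process per GPU; the process count comes from the `world_size` option declared in `config_parser` (visible in the training-script diff below). A minimal sketch, assuming `world_size` can be overridden on the command line like other config keys:

```bash
# hypothetical override: run training with 2 processes (one per GPU)
python ddp_train_nerf.py --config configs/tanks_and_temples/tat_training_truck.txt --world_size 2
```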


@ -1,49 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_intermediate_M60
expname = tat_intermediate_M60_bg_carve_latest
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,49 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_intermediate_Playground
expname = tat_intermediate_Playground_bg_carve_latest
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,48 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_intermediate_Playground
expname = tat_intermediate_Playground_ddp_bignet
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
N_rand = 256
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
chunk_size = 4096
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 512
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,49 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_intermediate_Train
expname = tat_intermediate_Train_bg_carve_latest
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,7 +1,7 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
datadir = ./data/tanks_and_temples
scene = tat_training_Truck
expname = tat_training_Truck_ddp_implicit
expname = tat_training_Truck
basedir = ./logs
config = None
ckpt_path = None
@ -10,27 +10,16 @@ testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 512
N_rand = 1024
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 50000000
### implicit
use_implicit = True
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
# chunk_size = 4096
chunk_size = 8192
### RENDERING
@ -39,15 +28,10 @@ max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,48 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_training_Truck
expname = tat_training_Truck_ddp_bignet
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
N_rand = 256
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
chunk_size = 4096
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 512
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,47 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere
scene = tat_training_Truck_subset
expname = tat_training_Truck_subset_bg_carvenew
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 250001
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### CASCADE
cascade_level = 2
cascade_samples = 64,64
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = False
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,54 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_intermediate_Playground
expname = tat_intermediate_Playground_ddp_sparse_addcarve
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### implicit
use_implicit = True
load_min_depth = True
regularize_weight = 0.1
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,54 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_intermediate_Playground
expname = tat_intermediate_Playground_ddp_sparse_addparam
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### implicit
use_implicit = True
load_min_depth = False
regularize_weight = 0.
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,54 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_intermediate_Playground
expname = tat_intermediate_Playground_ddp_sparse_addregularize_pretrain
basedir = ./logs
config = /home/zhangka2/gernot_experi/nerf_bg_latest_ddp/logs/tat_intermediate_Playground_ddp_sparse_addparam/model_210000.pth
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 4096
N_rand = 2048
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 500000
### implicit
use_implicit = True
load_min_depth = False
regularize_weight = 0.1
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
chunk_size = 16384
# chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,55 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_training_Truck
expname = tat_training_Truck_ddp_sparse_addcarve
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 512
N_rand = 1024
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 50000000
### implicit
use_implicit = True
load_min_depth = True
regularize_weight = 0.1
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
# chunk_size = 4096
chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,55 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_training_Truck
expname = tat_training_Truck_ddp_sparse_addparam
basedir = ./logs
config = None
ckpt_path = None
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 512
N_rand = 1024
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 50000000
### implicit
use_implicit = True
load_min_depth = False
regularize_weight = 0.
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
# chunk_size = 4096
chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -1,55 +0,0 @@
### INPUT
datadir = /home/zhangka2/gernot_experi/gernot_data/gernot_nerf_sphere_sparse
scene = tat_training_Truck
expname = tat_training_Truck_ddp_sparse_addregularize_pretrain
basedir = ./logs
config = None
ckpt_path = /home/zhangka2/gernot_experi/nerf_bg_latest_ddp/logs/tat_training_Truck_ddp_sparse_addparam/model_245000.pth
no_reload = False
testskip = 1
### TRAINING
N_iters = 1250001
# N_rand = 512
N_rand = 1024
lrate = 0.0005
lrate_decay_factor = 0.1
lrate_decay_steps = 50000000
### implicit
use_implicit = True
load_min_depth = False
regularize_weight = 0.1
### CASCADE
cascade_level = 2
cascade_samples = 64,128
near_depth = 0.
far_depth = 1.
### TESTING
render_only = False
render_test = False
render_train = False
# chunk_size = 16384
# chunk_size = 4096
chunk_size = 8192
### RENDERING
det = False
max_freq_log2 = 10
max_freq_log2_viewdirs = 4
netdepth = 8
netwidth = 256
raw_noise_std = 1.0
N_iters_perturb = 1000
inv_uniform = False
use_viewdirs = True
white_bkgd = False
### CONSOLE AND TENSORBOARD
i_img = 2000
i_print = 100
i_testset = 5000000
i_video = 5000000
i_weights = 5000

@ -24,16 +24,22 @@ def find_files(dir, exts):
return []
def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True):
def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True, only_img_files=False):
def parse_txt(filename):
assert os.path.isfile(filename)
nums = open(filename).read().split()
return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)
split_dir = '{}/{}/{}'.format(basedir, scene, split)
if only_img_files:
img_files = find_files('{}/rgb'.format(split_dir), exts=['*.png', '*.jpg'])
return img_files
# camera parameters files
intrinsics_files = find_files('{}/intrinsics'.format(split_dir), exts=['*.txt'])
pose_files = find_files('{}/pose'.format(split_dir), exts=['*.txt'])
logger.info('raw intrinsics_files: {}'.format(len(intrinsics_files)))
logger.info('raw pose_files: {}'.format(len(pose_files)))
@ -49,6 +55,7 @@ def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True):
assert(len(img_files) == cam_cnt)
else:
img_files = [None, ] * cam_cnt
# mask files
mask_files = find_files('{}/mask'.format(split_dir), exts=['*.png', '*.jpg'])
if len(mask_files) > 0:
@ -67,11 +74,12 @@ def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True):
else:
mindepth_files = [None, ] * cam_cnt
# assume all images have the same size
# assume all images have the same size as training image
train_imgfile = find_files('{}/{}/train/rgb'.format(basedir, scene), exts=['*.png', '*.jpg'])[0]
train_im = imageio.imread(train_imgfile)
H, W = train_im.shape[:2]
# create ray samplers
ray_samplers = []
for i in range(cam_cnt):
intrinsics = parse_txt(intrinsics_files[i])

@ -5,6 +5,9 @@ import torch.nn as nn
from utils import TINY_NUMBER, HUGE_NUMBER
from collections import OrderedDict
from nerf_network import Embedder, MLPNet
import os
import logging
logger = logging.getLogger(__package__)
######################################################################################
@ -44,14 +47,6 @@ def depth2pts_outside(ray_o, ray_d, depth):
class NerfNet(nn.Module):
def __init__(self, args):
'''
:param D: network depth
:param W: network width
:param input_ch: input channels for encodings of (x, y, z)
:param input_ch_viewdirs: input channels for encodings of view directions
:param skips: skip connection in network
:param use_viewdirs: if True, will use the view directions as input
'''
super().__init__()
# foreground
self.fg_embedder_position = Embedder(input_dim=3,
@ -146,3 +141,48 @@ class NerfNet(nn.Module):
('bg_depth', bg_depth_map),
('bg_lambda', bg_lambda)])
return ret
def remap_name(name):
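# keep at most the last three path components: short, but still unique per image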
name = name.replace('.', '-') # dot is not allowed by pytorch
if name[-1] == '/':
name = name[:-1]
idx = name.rfind('/')
for i in range(2):
if idx >= 0:
idx = name[:idx].rfind('/')
return name[idx + 1:]
class NerfNetWithAutoExpo(nn.Module):
def __init__(self, args, optim_autoexpo=False, img_names=None):
super().__init__()
self.nerf_net = NerfNet(args)
self.optim_autoexpo = optim_autoexpo
if self.optim_autoexpo:
assert(img_names is not None)
logger.info('Optimizing autoexposure!')
self.img_names = [remap_name(x) for x in img_names]
logger.info('\n'.join(self.img_names))
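# init [0.5, 0.] yields the identity exposure: scale = |0.5| + 0.5 = 1, shift = 0 (see forward below)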
self.autoexpo_params = nn.ParameterDict(OrderedDict([(x, nn.Parameter(torch.Tensor([0.5, 0.]))) for x in self.img_names]))
def forward(self, ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals, img_name=None):
'''
:param ray_o, ray_d: [..., 3]
:param fg_z_max: [...,]
:param fg_z_vals, bg_z_vals: [..., N_samples]
:return
'''
ret = self.nerf_net(ray_o, ray_d, fg_z_max, fg_z_vals, bg_z_vals)
if img_name is not None:
img_name = remap_name(img_name)
if self.optim_autoexpo and (img_name in self.autoexpo_params):
autoexpo = self.autoexpo_params[img_name]
scale = torch.abs(autoexpo[0]) + 0.5 # make sure scale is always positive
shift = autoexpo[1]
ret['autoexpo'] = (scale, shift)
return ret

@ -2,17 +2,17 @@ import torch
# import torch.nn as nn
import torch.optim
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import numpy as np
import os
from collections import OrderedDict
from ddp_model import NerfNet
# from collections import OrderedDict
# from ddp_model import NerfNet
import time
from data_loader_split import load_data_split
from utils import mse2psnr, colorize_np, to8b
import imageio
from ddp_run_nerf import config_parser, setup_logger, setup, cleanup, render_single_image
from ddp_train_nerf import config_parser, setup_logger, setup, cleanup, render_single_image, create_nerf
import logging
@ -37,46 +37,7 @@ def ddp_test_nerf(rank, args):
args.chunk_size = 4096
###### create network and wrap in ddp; each process should do this
# fix random seed just to make sure the network is initialized with same weights at different processes
torch.manual_seed(777)
# very important!!! otherwise it might introduce extra memory in rank=0 gpu
torch.cuda.set_device(rank)
models = OrderedDict()
models['cascade_level'] = args.cascade_level
models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
for m in range(models['cascade_level']):
net = NerfNet(args).to(rank)
net = DDP(net, device_ids=[rank], output_device=rank)
optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
models['net_{}'.format(m)] = net
models['optim_{}'.format(m)] = optim
start = -1
###### load pretrained weights; each process should do this
if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
ckpts = [args.ckpt_path]
else:
ckpts = [os.path.join(args.basedir, args.expname, f)
for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]
def path2iter(path):
tmp = os.path.basename(path)[:-4]
idx = tmp.rfind('_')
return int(tmp[idx + 1:])
ckpts = sorted(ckpts, key=path2iter)
logger.info('Found ckpts: {}'.format(ckpts))
if len(ckpts) > 0 and not args.no_reload:
fpath = ckpts[-1]
logger.info('Reloading from: {}'.format(fpath))
start = path2iter(fpath)
# configure map_location properly for different processes
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
to_load = torch.load(fpath, map_location=map_location)
for m in range(models['cascade_level']):
for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
models[name].load_state_dict(to_load[name])
start, models = create_nerf(rank, args)
render_splits = [x.strip() for x in args.render_splits.strip().split(',')]
# start testing
@ -157,4 +118,3 @@ if __name__ == '__main__':
setup_logger()
test()

@ -1,18 +1,20 @@
import torch
# import torch.nn as nn
import torch.nn as nn
import torch.optim
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import os
from collections import OrderedDict
from ddp_model import NerfNet
from ddp_model import NerfNetWithAutoExpo
import time
from data_loader_split import load_data_split
import numpy as np
from tensorboardX import SummaryWriter
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
import logging
import json
logger = logging.getLogger(__package__)
@ -274,41 +276,7 @@ def cleanup():
torch.distributed.destroy_process_group()
def ddp_train_nerf(rank, args):
###### set up multi-processing
setup(rank, args.world_size)
###### set up logger
logger = logging.getLogger(__package__)
setup_logger()
###### decide chunk size according to gpu memory
logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory))
if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
logger.info('setting batch size according to 24G gpu')
args.N_rand = 1024
args.chunk_size = 8192
else:
logger.info('setting batch size according to 12G gpu')
args.N_rand = 512
args.chunk_size = 4096
###### Create log dir and copy the config file
if rank == 0:
os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
f = os.path.join(args.basedir, args.expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(args.basedir, args.expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
torch.distributed.barrier()
ray_samplers = load_data_split(args.datadir, args.scene, split='train', try_load_min_depth=args.load_min_depth)
val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation', try_load_min_depth=args.load_min_depth)
def create_nerf(rank, args):
###### create network and wrap in ddp; each process should do this
# fix random seed just to make sure the network is initialized with same weights at different processes
torch.manual_seed(777)
@ -319,8 +287,15 @@ def ddp_train_nerf(rank, args):
models['cascade_level'] = args.cascade_level
models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
for m in range(models['cascade_level']):
net = NerfNet(args).to(rank)
net = DDP(net, device_ids=[rank], output_device=rank)
img_names = None
if args.optim_autoexpo:
# load training image names for autoexposure
f = os.path.join(args.basedir, args.expname, 'train_images.json')
with open(f) as file:
img_names = json.load(file)
net = NerfNetWithAutoExpo(args, optim_autoexpo=args.optim_autoexpo, img_names=img_names).to(rank)
net = DDP(net, device_ids=[rank], output_device=rank, find_unused_parameters=True)
# net = DDP(net, device_ids=[rank], output_device=rank)
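# find_unused_parameters=True: each step touches only the current image's autoexposure params, so the others receive no gradient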
optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
models['net_{}'.format(m)] = net
models['optim_{}'.format(m)] = optim
@ -351,6 +326,56 @@ def ddp_train_nerf(rank, args):
models[name].load_state_dict(to_load[name])
return start, models
def ddp_train_nerf(rank, args):
###### set up multi-processing
setup(rank, args.world_size)
###### set up logger
logger = logging.getLogger(__package__)
setup_logger()
###### decide chunk size according to gpu memory
logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory))
if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
logger.info('setting batch size according to 24G gpu')
args.N_rand = 1024
args.chunk_size = 8192
else:
logger.info('setting batch size according to 12G gpu')
args.N_rand = 512
args.chunk_size = 4096
###### Create log dir and copy the config file
if rank == 0:
os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
f = os.path.join(args.basedir, args.expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(args.basedir, args.expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
torch.distributed.barrier()
ray_samplers = load_data_split(args.datadir, args.scene, split='train',
try_load_min_depth=args.load_min_depth)
val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation',
try_load_min_depth=args.load_min_depth, skip=args.testskip)
# write training image names for autoexposure
if args.optim_autoexpo:
f = os.path.join(args.basedir, args.expname, 'train_images.json')
with open(f, 'w') as file:
img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))]
json.dump(img_names, file, indent=2)
###### create network and wrap in ddp; each process should do this
start, models = create_nerf(rank, args)
##### important!!!
# make sure different processes sample different rays
np.random.seed((rank + 1) * 777)
@ -416,13 +441,23 @@ def ddp_train_nerf(rank, args):
bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
optim.zero_grad()
ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth)
ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth, img_name=ray_batch['img_name'])
all_rets.append(ret)
rgb_gt = ray_batch['rgb'].to(rank)
loss = img2mse(ret['rgb'], rgb_gt)
scalars_to_log['level_{}/loss'.format(m)] = loss.item()
scalars_to_log['level_{}/psnr'.format(m)] = mse2psnr(loss.item())
if 'autoexpo' in ret:
scale, shift = ret['autoexpo']
scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item()
scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item()
# rgb_gt = scale * rgb_gt + shift
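# instead, undo the per-image exposure on the prediction (modeled as rendered ≈ scale * gt + shift)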
rgb_pred = (ret['rgb'] - shift) / scale
rgb_loss = img2mse(rgb_pred, rgb_gt)
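# regularize exposure toward the identity (scale = 1, shift = 0)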
loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale-1.)+torch.abs(shift))
else:
rgb_loss = img2mse(ret['rgb'], rgb_gt)
loss = rgb_loss
scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item()
scalars_to_log['level_{}/psnr'.format(m)] = mse2psnr(rgb_loss.item())
loss.backward()
optim.step()
@ -462,7 +497,7 @@ def ddp_train_nerf(rank, args):
logger.info('Logged a random training view in {} seconds'.format(dt))
log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].get_img(), mask=None, prefix='train/')
log_data = None
del log_data
torch.cuda.empty_cache()
if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
@ -523,6 +558,11 @@ def config_parser():
# multiprocess learning
parser.add_argument("--world_size", type=int, default='-1',
help='number of processes')
# optimize autoexposure
parser.add_argument("--optim_autoexpo", action='store_true',
help='optimize autoexposure parameters')
parser.add_argument("--lambda_autoexpo", type=float, default=1., help='regularization weight for autoexposure')
# learning rate options
parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
@ -530,8 +570,6 @@ def config_parser():
parser.add_argument("--lrate_decay_steps", type=int, default=5000,
help='decay learning rate by a factor every specified number of steps')
# rendering options
parser.add_argument("--inv_uniform", action='store_true',
help='if True, will uniformly sample inverse depths')
parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
parser.add_argument("--max_freq_log2", type=int, default=10,
help='log2 of max freq for positional encoding (3D location)')

@ -172,11 +172,12 @@ class RaySamplerSingleImage(object):
('depth', depth),
('rgb', rgb),
('mask', mask),
('min_depth', min_depth)
('min_depth', min_depth),
('img_name', self.img_path)
])
# return torch tensors
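# (skip non-array entries such as the img_name string)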
for k in ret:
if ret[k] is not None:
if isinstance(ret[k], np.ndarray):
ret[k] = torch.from_numpy(ret[k])
return ret

@ -1,19 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C pascal
#SBATCH --mem=40G
#SBATCH --time=24:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=high
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
#$PYTHON -u $CODE_DIR/ddp_test_nerf.py --config $CODE_DIR/configs/lf_data/lf_africa.txt
$PYTHON -u $CODE_DIR/ddp_test_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck.txt

@ -1,24 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo $CODE_DIR
#$PYTHON -u $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt
$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt
#$PYTHON -u $CODE_DIR/nerf_render_path.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt
#$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_addregularize.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addparam.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addcarve.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_intermediate_playground_addregularize.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addparam.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
####SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addcarve.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=60G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples_sparse/tat_training_truck_addregularize.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=24:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=high
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_africa.txt

@ -1,17 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=100G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
######## #SBATCH --qos=high
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_basket.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=normal
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_ship.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=80G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --qos=normal
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/lf_data/lf_torch.txt

@ -1,18 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo $CODE_DIR
$PYTHON -u $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_m60.txt
$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_m60.txt

@ -1,18 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo $CODE_DIR
$PYTHON -u $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt
$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground.txt

@ -1,15 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 25
#SBATCH -C turing
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_playground_bignet.txt

@ -1,18 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest
echo $CODE_DIR
$PYTHON -u $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_train.txt
$PYTHON -u $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_intermediate_train.txt

@ -1,16 +0,0 @@
#!/bin/bash
#SBATCH -p q6
#SBATCH --gres=gpu:4
#SBATCH -c 10
#SBATCH -C turing
#SBATCH --mem=50G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck.txt

@ -1,15 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:8
#SBATCH -c 25
#SBATCH -C turing
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
PYTHON=/home/zhangka2/anaconda3/envs/nerf-ddp/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg_latest_ddp
echo $CODE_DIR
$PYTHON -u $CODE_DIR/ddp_run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_bignet.txt

@ -1,19 +0,0 @@
#!/bin/bash
#SBATCH -p gpu
#SBATCH --gres=gpu:3
#SBATCH -c 8
#SBATCH -C turing
#SBATCH --mem=16G
#SBATCH --time=48:00:00
#SBATCH --output=slurm_%A.out
#SBATCH --exclude=isl-gpu17
PYTHON=/home/zhangka2/anaconda3/envs/nerf/bin/python
CODE_DIR=/home/zhangka2/gernot_experi/nerf_bg
echo $CODE_DIR
$PYTHON $CODE_DIR/run_nerf.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_subset.txt
$PYTHON $CODE_DIR/nerf_render_image.py --config $CODE_DIR/configs/tanks_and_temples/tat_training_truck_subset.txt