From e60d030d346d04ab630b82e00d82e4599e09a940 Mon Sep 17 00:00:00 2001 From: Kai-46 Date: Sun, 11 Oct 2020 21:36:58 -0400 Subject: [PATCH] clean code --- .gitignore | 138 +++++++ __pycache__/utils.cpython-37.pyc | Bin 4212 -> 0 bytes old_scripts/data_loader.py | 181 --------- old_scripts/ddp_run_nerf.py | 617 ------------------------------- old_scripts/nerf_sample_ray.py | 231 ------------ 5 files changed, 138 insertions(+), 1029 deletions(-) create mode 100644 .gitignore delete mode 100644 __pycache__/utils.cpython-37.pyc delete mode 100644 old_scripts/data_loader.py delete mode 100644 old_scripts/ddp_run_nerf.py delete mode 100644 old_scripts/nerf_sample_ray.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5391d87 --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ \ No newline at end of file diff --git a/__pycache__/utils.cpython-37.pyc b/__pycache__/utils.cpython-37.pyc deleted file mode 100644 index 5f237590e6c468b1f10de6422ed780eb4d9e5d88..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4212 zcmaJ^O>7&-72cU$lFJ{Fk}T`jKeA&4V2etQoVJN;*ohrG4%(<~oFM2n>2}2#QcLY} z>Di_JvOLAspKU^FGVlC_tay0pHqsSbSO|XhZaGB^n1IMthh#(*thfZX6DWN z`QGE*si}&A=QscS%g;LVhVeJ*oIFJ|Uc(z-V}>CN!2-j{zs%u+JEkxlORNe@6rM1r zAc~^&#Bhqj&b}p478UeuQO){gQOmv+VTh@q8rGZ|quuJkR9JWFS(^n7XBuOTpb^eE zGn27dXI5MhGh!B}oD)~YoH&oxoOnUZiv_gK%XzW5!+*)tD%P@K!CA<@i_W50a+W%V zv%LR?T#*Z-3vy}9K4xziA8t4oRb3 z{``mAEt4kj(3kXj4{z*%q+A#~wz`xUBbFK|TQqVyXN<8UM)}uAQ$8{UB%=9jmXEQn zYYDbw%o#~x!B{aKL9dX8ZytYzr&-j6BvP$yZH3b7X)_*#x;*H`j|Wm7N?nMzy}s1d zci-CliM#pU+jsB3(=;{P*X#f*ym;q#hN?l*FTc6|UMy9-ez@)RIy>IRdMgrgy(d-M zZEdvOFj8`TkoZBo)<4iCKkRITv8>ZY>Ba`L**sU&Z^56khrR{8nKo!9HAc{+;KCf6 zxPi5K%rUBn@0duF=~7>+aFEC*Q)QYX(Q-0{mbPyNUbrQ^+jHoBf!9#8AkRONc=991 zY(&RCJy`O0s{^8k4=V{$Jkkcb+J-yj2uc1Llj*TJgHe&dm zSeTRb6P{bJ$u?mHGxTAw(A!rw)-}!i$_P|LEf9I;^5tHnLND+So z%+7}gH6Hk#fs*&U-mVwF-sxbdm=B@FTcI|)GHH=t(nCU*@MhNg0i;eQ9(u%wg|yIR zgnvgsh=iwwkJ+%877(GuF&mZ=Yg9;~#-W`QyQQw3GDLd`>&n6&R+93lLL+HqY!0j7 zsCH{|{k*|lM8?fkkkwr;7|65y zyNDM|dHer6)e617+w($s^X!7a+mb;(pky*oy?nOGJ2$fDi#x{@e#Zi&$#tpicVfht zHbvi`c;mnCz;_DD?`$W>QO;(Oh#J(}KhChj-jyoxTV8+=vPcqbQM>}3{nk#btC8|$ zFYyvT>Yauv!5?|zBfK%ieFs|ous&o07?L6n{E9!ephtFC`sn73wTvNNss)?N$&$>i9MqC##MLR5MV( zDdkLonM!OGjtjn7x*q5V%9r(St@3o}1PcEWs_a;}m>%~&n zV{Ama&&Ai=+t*dId z>LQ6+QZh@fitr776GUEeV?z;~XcKF-+1lNpoI<*DDsOGRd;cBx^>^;S-aMz-wrYY; zeS?_Gr;e)&xj-hI5O>coOc6GzZxLAoaqR38tzI%|!F-O{+ImS>vWa{H7VE;EFOqFt z+?I4X&HYfby=QaE`RDq?w{cGVGKf*Bu^O|@3ahYbZm~MP96y_*#izLqY8PzGE%O)I zG`m1W02MwDTI3+~P#UK!Noj|TwG=^4c zrIa(Q6glG-wF=<2(gM)c=*>-XEqF$JY$U}61L%!;)LYnV5mnRBNJ=AMG|nof#T3Ye zD$)QBo=~3S;IK52?{l0{ruI`Vq|FLpCHCOJa-L{lHEH2gesP%miUFyj_+C*0lrHW!V@ONZp^W5{REdZRC$R6HiImaiGLRD#T)$8|V zPq-~0&lxevyPd+mi>(2kH1|yHgIxhvh$FYx{{@)i7eP>T!Vl|g7Voravl^dYPuz!%@N)C6{Rul!qBzL3#c z*nc_mYxM$-NAfY-{aBYC4^Z|d2UIn4brmzHp7Za6N#nt0^IE15)y8zDG^hITy*KYq z{`2IXOn*Y0D*n5;58uDHaqr>#j@?n-!G`K=Wi_;8CDAKej*Wjd8~wPa^kpPk%36WH zwYKH8c5sjQq~UsPB45Ju+qJkc}NEv~D16g&8L z<#nVVQKRDuUpclH1TH$UF2<2cq|hZw)|k?kpU6;KeGgb~bt4odg?82{=AovGeyF4; za4SFDQ{X=144GR@{#k{Q7zFY*5ZZ{p17e%DX)&uMyJT~FD$53WM)+~%N@c3j Qz#O;WGjl-yI-BGF1!F}1CIA2c diff --git a/old_scripts/data_loader.py b/old_scripts/data_loader.py deleted file mode 100644 index 6c0fef2..0000000 --- a/old_scripts/data_loader.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -import numpy as np -import imageio -from collections import OrderedDict -import logging - -logger = logging.getLogger(__package__) - -######################################################################################################################## -# camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention) -# poses is camera-to-world -######################################################################################################################## - -def load_data(basedir, scene, testskip=8): - def parse_txt(filename): - assert os.path.isfile(filename) - nums = open(filename).read().split() - return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32) - - def dir2poses(posedir): - poses = np.stack( - [parse_txt(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) - poses = poses.astype(np.float32) - return poses - - def dir2intrinsics(intrinsicdir): - intrinsics = np.stack( - [parse_txt(os.path.join(intrinsicdir, f)) for f in sorted(os.listdir(intrinsicdir)) if f.endswith('txt')], 0) - intrinsics = intrinsics.astype(np.float32) - return intrinsics - - intrinsics = dir2intrinsics('{}/{}/train/intrinsics'.format(basedir, scene)) - testintrinsics = dir2poses('{}/{}/test/intrinsics'.format(basedir, scene)) - testintrinsics = testintrinsics[::testskip] - valintrinsics = dir2poses('{}/{}/validation/intrinsics'.format(basedir, scene)) - valintrinsics = valintrinsics[::testskip] - - print(intrinsics.shape, testintrinsics.shape, valintrinsics.shape) - - poses = dir2poses('{}/{}/train/pose'.format(basedir, scene)) - testposes = dir2poses('{}/{}/test/pose'.format(basedir, scene)) - testposes = testposes[::testskip] - valposes = dir2poses('{}/{}/validation/pose'.format(basedir, scene)) - valposes = valposes[::testskip] - - print(poses.shape, testposes.shape, valposes.shape) - - imgd = '{}/{}/train/rgb'.format(basedir, scene) - imgfiles = ['{}/{}'.format(imgd, f) - for f in sorted(os.listdir(imgd)) if f.endswith('png') or f.endswith('jpg')] - imgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in imgfiles] - - maskd = '{}/{}/train/mask'.format(basedir, scene) - if os.path.isdir(maskd): - logger.info('Loading mask from: {}'.format(maskd)) - maskfiles = ['{}/{}'.format(maskd, f) - for f in sorted(os.listdir(maskd)) if f.endswith('png') or f.endswith('jpg')] - masks = [imageio.imread(f).astype(np.float32) / 255. for f in maskfiles] - else: - masks = [None for im in imgs] - - # load min_depth map - min_depthd = '{}/{}/train/min_depth'.format(basedir, scene) - if os.path.isdir(min_depthd): - logger.info('Loading min_depth from: {}'.format(min_depthd)) - max_depth = float(open('{}/{}/train/max_depth.txt'.format(basedir, scene)).readline().strip()) - min_depthfiles = ['{}/{}'.format(min_depthd, f) - for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')] - min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles] - else: - min_depths = [None for im in imgs] - - testimgd = '{}/{}/test/rgb'.format(basedir, scene) - testimgfiles = ['{}/{}'.format(testimgd, f) - for f in sorted(os.listdir(testimgd)) if f.endswith('png') or f.endswith('jpg')] - testimgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in testimgfiles] - testimgfiles = testimgfiles[::testskip] - testimgs = testimgs[::testskip] - - testmaskd = '{}/{}/test/mask'.format(basedir, scene) - if os.path.isdir(testmaskd): - logger.info('Loading mask from: {}'.format(testmaskd)) - testmaskfiles = ['{}/{}'.format(testmaskd, f) - for f in sorted(os.listdir(testmaskd)) if f.endswith('png') or f.endswith('jpg')] - testmasks = [imageio.imread(f).astype(np.float32) / 255. for f in testmaskfiles] - else: - testmasks = [None for im in testimgs] - - # load min_depth map - min_depthd = '{}/{}/test/min_depth'.format(basedir, scene) - if os.path.isdir(min_depthd): - logger.info('Loading min_depth from: {}'.format(min_depthd)) - max_depth = float(open('{}/{}/test/max_depth.txt'.format(basedir, scene)).readline().strip()) - min_depthfiles = ['{}/{}'.format(min_depthd, f) - for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')] - test_min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles] - else: - test_min_depths = [None for im in testimgs] - - valimgd = '{}/{}/validation/rgb'.format(basedir, scene) - valimgfiles = ['{}/{}'.format(valimgd, f) - for f in sorted(os.listdir(valimgd)) if f.endswith('png') or f.endswith('jpg')] - valimgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in valimgfiles] - valimgfiles = valimgfiles[::testskip] - valimgs = valimgs[::testskip] - - valmaskd = '{}/{}/validation/mask'.format(basedir, scene) - if os.path.isdir(valmaskd): - logger.info('Loading mask from: {}'.format(valmaskd)) - valmaskfiles = ['{}/{}'.format(valmaskd, f) - for f in sorted(os.listdir(valmaskd)) if f.endswith('png') or f.endswith('jpg')] - valmasks = [imageio.imread(f).astype(np.float32) / 255. for f in valmaskfiles] - else: - valmasks = [None for im in valimgs] - - # load min_depth map - min_depthd = '{}/{}/validation/min_depth'.format(basedir, scene) - if os.path.isdir(min_depthd): - logger.info('Loading min_depth from: {}'.format(min_depthd)) - max_depth = float(open('{}/{}/validation/max_depth.txt'.format(basedir, scene)).readline().strip()) - min_depthfiles = ['{}/{}'.format(min_depthd, f) - for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')] - val_min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles] - else: - val_min_depths = [None for im in valimgs] - - # data format for training/testing - print(len(imgs), len(testimgs), len(valimgs)) - all_imgs = imgs + valimgs + testimgs - all_masks = masks + valmasks + testmasks - all_min_depths = min_depths + val_min_depths + test_min_depths - all_paths = imgfiles + valimgfiles + testimgfiles - - counts = [0] + [len(x) for x in [imgs, valimgs, testimgs]] - counts = np.cumsum(counts) - i_split = [list(np.arange(counts[i], counts[i+1])) for i in range(3)] - - intrinsics = np.concatenate([intrinsics, valintrinsics, testintrinsics], 0) - poses = np.concatenate([poses, valposes, testposes], 0) - img_sizes = np.stack([np.array(x.shape[:2]) for x in all_imgs], axis=0) # [H, W] - cnt = len(all_imgs) - all_cams = np.concatenate((img_sizes.astype(dtype=np.float32), intrinsics.reshape((cnt, -1)), poses.reshape((cnt, -1))), axis=1) - - if os.path.isdir('{}/{}/camera_path/intrinsics'.format(basedir, scene)): - camera_path_intrinsics = dir2poses('{}/{}/camera_path/intrinsics'.format(basedir, scene)) - camera_path_poses = dir2poses('{}/{}/camera_path/pose'.format(basedir, scene)) - # assume centered principal points - # img_sizes = np.stack((camera_path_intrinsics[:, 1, 2]*2, camera_path_intrinsics[:, 0, 2]*2), axis=1) # [H, W] - # img_sizes = np.int32(img_sizes) - - H = all_cams[0, 0] - W = all_cams[0, 1] - img_sizes = np.stack((np.ones_like(camera_path_intrinsics[:, 1, 2])*H, np.ones_like(camera_path_intrinsics[:, 0, 2])*W), axis=1) # [H, W] - - cnt = len(camera_path_intrinsics) - render_cams = np.concatenate( - (img_sizes.astype(dtype=np.float32), camera_path_intrinsics.reshape((cnt, -1)), camera_path_poses.reshape((cnt, -1))), - axis=1) - else: - render_cams = None - - print(all_cams.shape) - - data = OrderedDict([('images', all_imgs), - ('masks', all_masks), - ('paths', all_paths), - ('min_depths', all_min_depths), - ('cameras', all_cams), - ('i_train', i_split[0]), - ('i_val', i_split[1]), - ('i_test', i_split[2]), - ('render_cams', render_cams)]) - - logger.info('Data statistics:') - logger.info('\t # of training views: {}'.format(len(data['i_train']))) - logger.info('\t # of validation views: {}'.format(len(data['i_val']))) - logger.info('\t # of test views: {}'.format(len(data['i_test']))) - if data['render_cams'] is not None: - logger.info('\t # of render cameras: {}'.format(len(data['render_cams']))) - - return data diff --git a/old_scripts/ddp_run_nerf.py b/old_scripts/ddp_run_nerf.py deleted file mode 100644 index 963126c..0000000 --- a/old_scripts/ddp_run_nerf.py +++ /dev/null @@ -1,617 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim -import torch.distributed -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.multiprocessing - -import os -from collections import OrderedDict -from ddp_model import NerfNet -import time -# from data_loader import load_data -# from nerf_sample_ray import RaySamplerSingleImage - -from data_loader_split import load_data_split - -import numpy as np -from tensorboardX import SummaryWriter -from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER - -import logging -logger = logging.getLogger(__package__) - - -def setup_logger(): - # create logger - logger = logging.getLogger(__package__) - logger.setLevel(logging.DEBUG) - - # create console handler and set level to debug - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - - # create formatter - formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') - - # add formatter to ch - ch.setFormatter(formatter) - - # add ch to logger - logger.addHandler(ch) - - -def intersect_sphere(ray_o, ray_d): - ''' - ray_o, ray_d: [..., 3] - compute the depth of the intersection point between this ray and unit sphere - ''' - # note: d1 becomes negative if this mid point is behind camera - d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) - p = ray_o + d1.unsqueeze(-1) * ray_d - # consider the case where the ray does not intersect the sphere - ray_d_cos = 1. / torch.norm(ray_d, dim=-1) - d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos - - return d1 + d2 - - -def perturb_samples(z_vals): - # get intervals between samples - mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) - upper = torch.cat([mids, z_vals[..., -1:]], dim=-1) - lower = torch.cat([z_vals[..., 0:1], mids], dim=-1) - # uniform samples in those intervals - t_rand = torch.rand_like(z_vals) - z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples] - - return z_vals - - -def sample_pdf(bins, weights, N_samples, det=False): - ''' - :param bins: tensor of shape [..., M+1], M is the number of bins - :param weights: tensor of shape [..., M] - :param N_samples: number of samples along each ray - :param det: if True, will perform deterministic sampling - :return: [..., N_samples] - ''' - # Get pdf - weights = weights + TINY_NUMBER # prevent nans - pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M] - cdf = torch.cumsum(pdf, dim=-1) # [..., M] - cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [..., M+1] - - # Take uniform samples - dots_sh = list(weights.shape[:-1]) - M = weights.shape[-1] - - min_cdf = 0.00 - max_cdf = 1.00 # prevent outlier samples - - if det: - u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device) - u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,]) # [..., N_samples] - else: - sh = dots_sh + [N_samples] - u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf # [..., N_samples] - - # Invert CDF - # [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,] - above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long() - - # random sample inside each bin - below_inds = torch.clamp(above_inds-1, min=0) - inds_g = torch.stack((below_inds, above_inds), dim=-1) # [..., N_samples, 2] - - cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1] - cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [..., N_samples, 2] - - bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1] - bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [..., N_samples, 2] - - # fix numeric issue - denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples] - denom = torch.where(denom1 - writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step) - - rgb_im = img_HWC2CHW(log_data[m]['fg_rgb']) - rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1 - writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step) - depth = log_data[m]['fg_depth'] - depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step) - - rgb_im = img_HWC2CHW(log_data[m]['bg_rgb']) - rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1 - writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step) - depth = log_data[m]['bg_depth'] - depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step) - bg_lambda = log_data[m]['bg_lambda'] - bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True, - mask=mask)) - writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step) - - -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - # initialize the process group - torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - - -def cleanup(): - torch.distributed.destroy_process_group() - - -def ddp_train_nerf(rank, args): - ###### set up multi-processing - setup(rank, args.world_size) - ###### set up logger - logger = logging.getLogger(__package__) - setup_logger() - - ###### Create log dir and copy the config file - if rank == 0: - os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True) - f = os.path.join(args.basedir, args.expname, 'args.txt') - with open(f, 'w') as file: - for arg in sorted(vars(args)): - attr = getattr(args, arg) - file.write('{} = {}\n'.format(arg, attr)) - if args.config is not None: - f = os.path.join(args.basedir, args.expname, 'config.txt') - with open(f, 'w') as file: - file.write(open(args.config, 'r').read()) - torch.distributed.barrier() - - ###### load data and create ray samplers; each process should do this - # data = load_data(args.datadir, args.scene, args.testskip) - # ray_samplers = [] - # for i in data['i_train']: - # ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i], - # img=data['images'][i], - # img_path=data['paths'][i], - # mask=data['masks'][i], - # min_depth=data['min_depths'][i])) - # - # val_ray_samplers = [] - # for i in data['i_val']: - # val_ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i], - # img=data['images'][i], - # img_path=data['paths'][i], - # mask=data['masks'][i], - # min_depth=data['min_depths'][i])) - # # free memory - # del data - - ray_samplers = load_data_split(args.datadir, args.scene, split='train') - val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation') - - ###### create network and wrap in ddp; each process should do this - # fix random seed just to make sure the network is initialized with same weights at different processes - torch.manual_seed(777) - # very important!!! otherwise it might introduce extra memory in rank=0 gpu - torch.cuda.set_device(rank) - - models = OrderedDict() - models['cascade_level'] = args.cascade_level - models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')] - for m in range(models['cascade_level']): - net = NerfNet(args).to(rank) - net = DDP(net, device_ids=[rank], output_device=rank) - optim = torch.optim.Adam(net.parameters(), lr=args.lrate) - models['net_{}'.format(m)] = net - models['optim_{}'.format(m)] = optim - - start = -1 - - ###### load pretrained weights; each process should do this - if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)): - ckpts = [args.ckpt_path] - else: - ckpts = [os.path.join(args.basedir, args.expname, f) - for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')] - def path2iter(path): - tmp = os.path.basename(path)[:-4] - idx = tmp.rfind('_') - return int(tmp[idx + 1:]) - ckpts = sorted(ckpts, key=path2iter) - logger.info('Found ckpts: {}'.format(ckpts)) - if len(ckpts) > 0 and not args.no_reload: - fpath = ckpts[-1] - logger.info('Reloading from: {}'.format(fpath)) - start = path2iter(fpath) - # configure map_location properly for different processes - map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} - to_load = torch.load(fpath, map_location=map_location) - for m in range(models['cascade_level']): - for name in ['net_{}'.format(m), 'optim_{}'.format(m)]: - models[name].load_state_dict(to_load[name]) - models[name].load_state_dict(to_load[name]) - - ##### important!!! - # make sure different processes sample different rays - np.random.seed((rank + 1) * 777) - # make sure different processes have different perturbations in depth samples - torch.manual_seed((rank + 1) * 777) - - ##### only main process should do the logging - if rank == 0: - writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname)) - - # start training - what_val_to_log = 0 # helper variable for parallel rendering of a image - what_train_to_log = 0 - for global_step in range(start+1, start+1+args.N_iters): - time0 = time.time() - scalars_to_log = OrderedDict() - ### Start of core optimization loop - scalars_to_log['resolution'] = ray_samplers[0].resolution_level - # randomly sample rays and move to device - i = np.random.randint(low=0, high=len(ray_samplers)) - ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False) - for key in ray_batch: - if torch.is_tensor(ray_batch[key]): - ray_batch[key] = ray_batch[key].to(rank) - - # forward and backward - dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays - all_rets = [] # results on different cascade levels - for m in range(models['cascade_level']): - optim = models['optim_{}'.format(m)] - net = models['net_{}'.format(m)] - - # sample depths - N_samples = models['cascade_samples'][m] - if m == 0: - # foreground depth - fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,] - # fg_near_depth = 0.18 * torch.ones_like(fg_far_depth) - fg_near_depth = ray_batch['min_depth'] # [..., 3] - step = (fg_far_depth - fg_near_depth) / (N_samples - 1) - fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples] - fg_depth = perturb_samples(fg_depth) # random perturbation during training - - # background depth - bg_depth = torch.linspace(0., 1., N_samples).view( - [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank) - bg_depth = perturb_samples(bg_depth) # random perturbation during training - else: - # sample pdf and concat with earlier samples - fg_weights = ret['fg_weights'].clone().detach() - fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1] - fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2] - fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights, - N_samples=N_samples, det=False) # [..., N_samples] - fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1)) - - # sample pdf and concat with earlier samples - bg_weights = ret['bg_weights'].clone().detach() - bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1]) - bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2] - bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights, - N_samples=N_samples, det=False) # [..., N_samples] - bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1)) - - optim.zero_grad() - ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth) - all_rets.append(ret) - - rgb_gt = ray_batch['rgb'].to(rank) - loss = img2mse(ret['rgb'], rgb_gt) - scalars_to_log['level_{}/loss'.format(m)] = loss.item() - scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item()) - # regularize sigma with photo-consistency - loss = loss + img2mse(ret['diffuse_rgb'], rgb_gt) - loss.backward() - optim.step() - - # # clean unused memory - # torch.cuda.empty_cache() - - ### end of core optimization loop - dt = time.time() - time0 - scalars_to_log['iter_time'] = dt - - ### only main process should do the logging - if rank == 0 and (global_step % args.i_print == 0 or global_step < 10): - logstr = '{} step: {} '.format(args.expname, global_step) - for k in scalars_to_log: - logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k]) - writer.add_scalar(k, scalars_to_log[k], global_step) - logger.info(logstr) - - ### each process should do this; but only main process merges the results - if global_step % args.i_img == 0 or global_step == start+1: - #### critical: make sure each process is working on the same random image - time0 = time.time() - idx = what_val_to_log % len(val_ray_samplers) - log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size) - what_val_to_log += 1 - dt = time.time() - time0 - if rank == 0: # only main process should do this - logger.info('Logged a random validation view in {} seconds'.format(dt)) - log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].img_orig, mask=None, prefix='val/') - - time0 = time.time() - idx = what_train_to_log % len(ray_samplers) - log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size) - what_train_to_log += 1 - dt = time.time() - time0 - if rank == 0: # only main process should do this - logger.info('Logged a random training view in {} seconds'.format(dt)) - log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].img_orig, mask=None, prefix='train/') - - log_data = None - torch.cuda.empty_cache() - - if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0): - # saving checkpoints and logging - fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step)) - to_save = OrderedDict() - for m in range(models['cascade_level']): - name = 'net_{}'.format(m) - to_save[name] = models[name].state_dict() - - name = 'optim_{}'.format(m) - to_save[name] = models[name].state_dict() - torch.save(to_save, fpath) - - # clean up for multi-processing - cleanup() - - -def config_parser(): - import configargparse - parser = configargparse.ArgumentParser() - parser.add_argument('--config', is_config_file=True, help='config file path') - parser.add_argument("--expname", type=str, help='experiment name') - parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs') - - # dataset options - parser.add_argument("--datadir", type=str, default=None, help='input data directory') - parser.add_argument("--scene", type=str, default=None, help='scene name') - parser.add_argument("--testskip", type=int, default=8, - help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') - - # model size - parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network') - parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network') - parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D') - - # checkpoints - parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') - parser.add_argument("--ckpt_path", type=str, default=None, - help='specific weights npy file to reload for coarse network') - - # batch size - parser.add_argument("--N_rand", type=int, default=32 * 32 * 2, - help='batch size (number of random rays per gradient step)') - parser.add_argument("--chunk_size", type=int, default=1024 * 8, - help='number of rays processed in parallel, decrease if running out of memory') - - # iterations - parser.add_argument("--N_iters", type=int, default=250001, - help='number of iterations') - - # cascade training - parser.add_argument("--cascade_level", type=int, default=2, - help='number of cascade levels') - parser.add_argument("--cascade_samples", type=str, default='64,64', - help='samples at each level') - parser.add_argument("--devices", type=str, default='0,1', - help='cuda device for each level') - parser.add_argument("--bg_devices", type=str, default='0,2', - help='cuda device for the background of each level') - - parser.add_argument("--world_size", type=int, default='-1', - help='number of processes') - - # mixed precison training - parser.add_argument("--opt_level", type=str, default='O1', - help='mixed precison training') - - parser.add_argument("--near_depth", type=float, default=0.1, - help='near depth plane') - parser.add_argument("--far_depth", type=float, default=50., - help='far depth plane') - - # learning rate options - parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') - parser.add_argument("--lrate_decay_factor", type=float, default=0.1, - help='decay learning rate by a factor every specified number of steps') - parser.add_argument("--lrate_decay_steps", type=int, default=5000, - help='decay learning rate by a factor every specified number of steps') - - # rendering options - parser.add_argument("--inv_uniform", action='store_true', - help='if True, will uniformly sample inverse depths') - parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples') - parser.add_argument("--max_freq_log2", type=int, default=10, - help='log2 of max freq for positional encoding (3D location)') - parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4, - help='log2 of max freq for positional encoding (2D direction)') - parser.add_argument("--N_iters_perturb", type=int, default=1000, - help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck') - parser.add_argument("--raw_noise_std", type=float, default=1., - help='std dev of noise added to regularize sigma output, 1e0 recommended') - parser.add_argument("--white_bkgd", action='store_true', - help='apply the trick to avoid fitting to white background') - - # no training; render only - parser.add_argument("--render_only", action='store_true', - help='do not optimize, reload weights and render out render_poses path') - parser.add_argument("--render_train", action='store_true', help='render the training set') - parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') - - # no training; extract mesh only - parser.add_argument("--mesh_only", action='store_true', - help='do not optimize, extract mesh from pretrained model') - parser.add_argument("--N_pts", type=int, default=256, - help='voxel resolution; N_pts * N_pts * N_pts') - parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50', - help='threshold(s) for mesh extraction; can use multiple thresholds') - - # logging/saving options - parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin') - parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') - parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') - parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving') - parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving') - - return parser - - -def train(): - parser = config_parser() - args = parser.parse_args() - logger.info(parser.format_values()) - - if args.world_size == -1: - args.world_size = torch.cuda.device_count() - logger.info('Using # gpus: {}'.format(args.world_size)) - torch.multiprocessing.spawn(ddp_train_nerf, - args=(args,), - nprocs=args.world_size, - join=True) - - -if __name__ == '__main__': - setup_logger() - train() - - diff --git a/old_scripts/nerf_sample_ray.py b/old_scripts/nerf_sample_ray.py deleted file mode 100644 index d0ad801..0000000 --- a/old_scripts/nerf_sample_ray.py +++ /dev/null @@ -1,231 +0,0 @@ -import numpy as np -from collections import OrderedDict -import torch -import cv2 - - -######################################################################################################################## -# ray batch sampling -######################################################################################################################## - -def parse_camera(params): - H, W = params[:2] - intrinsics = params[2:18].reshape((4, 4)) - c2w = params[18:34].reshape((4, 4)) - - return int(W), int(H), intrinsics.astype(np.float32), c2w.astype(np.float32) - - -def get_rays_single_image(H, W, intrinsics, c2w): - ''' - :param H: image height - :param W: image width - :param intrinsics: 4 by 4 intrinsic matrix - :param c2w: 4 by 4 camera to world extrinsic matrix - :return: - ''' - u, v = np.meshgrid(np.arange(W), np.arange(H)) - - u = u.reshape(-1).astype(dtype=np.float32) + 0.5 # add half pixel - v = v.reshape(-1).astype(dtype=np.float32) + 0.5 - pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W) - - rays_d = np.dot(np.linalg.inv(intrinsics[:3, :3]), pixels) - rays_d = np.dot(c2w[:3, :3], rays_d) # (3, H*W) - rays_d = rays_d.transpose((1, 0)) # (H*W, 3) - - rays_o = c2w[:3, 3].reshape((1, 3)) - rays_o = np.tile(rays_o, (rays_d.shape[0], 1)) # (H*W, 3) - - depth = np.linalg.inv(c2w)[2, 3] - depth = depth * np.ones((rays_o.shape[0],), dtype=np.float32) # (H*W,) - - return rays_o, rays_d, depth - - -class RaySamplerSingleImage(object): - def __init__(self, cam_params, img_path=None, img=None, resolution_level=1, mask=None, min_depth=None): - super().__init__() - self.W_orig, self.H_orig, self.intrinsics_orig, self.c2w_mat = parse_camera(cam_params) - - self.img_path = img_path - self.img_orig = img - self.mask_orig = mask - self.min_depth_orig = min_depth - - self.resolution_level = -1 - self.set_resolution_level(resolution_level) - - def set_resolution_level(self, resolution_level): - if resolution_level != self.resolution_level: - self.resolution_level = resolution_level - self.W = self.W_orig // resolution_level - self.H = self.H_orig // resolution_level - self.intrinsics = np.copy(self.intrinsics_orig) - self.intrinsics[:2, :3] /= resolution_level - if self.img_orig is not None: - self.img = cv2.resize(self.img_orig, (self.W, self.H), interpolation=cv2.INTER_AREA) - self.img = self.img.reshape((-1, 3)) - else: - self.img = None - - if self.mask_orig is not None: - self.mask = cv2.resize(self.mask_orig, (self.W, self.H), interpolation=cv2.INTER_NEAREST) - self.mask = self.mask.reshape((-1)) - else: - self.mask = None - - if self.min_depth_orig is not None: - self.min_depth = cv2.resize(self.min_depth_orig, (self.W, self.H), interpolation=cv2.INTER_LINEAR) - self.min_depth = self.min_depth.reshape((-1)) - else: - self.min_depth = None - - self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W, - self.intrinsics, self.c2w_mat) - - def get_all(self): - if self.min_depth is not None: - min_depth = self.min_depth - else: - min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0]) - - ret = OrderedDict([ - ('ray_o', self.rays_o), - ('ray_d', self.rays_d), - ('depth', self.depth), - ('rgb', self.img), - ('mask', self.mask), - ('min_depth', min_depth) - ]) - # return torch tensors - for k in ret: - if ret[k] is not None: - ret[k] = torch.from_numpy(ret[k]) - return ret - - def random_sample(self, N_rand, center_crop=False): - ''' - :param N_rand: number of rays to be casted - :return: - ''' - if center_crop: - half_H = self.H // 2 - half_W = self.W // 2 - quad_H = half_H // 2 - quad_W = half_W // 2 - - # pixel coordinates - u, v = np.meshgrid(np.arange(half_W-quad_W, half_W+quad_W), - np.arange(half_H-quad_H, half_H+quad_H)) - u = u.reshape(-1) - v = v.reshape(-1) - - select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False) - - # Convert back to original image - select_inds = v[select_inds] * self.W + u[select_inds] - else: - # Random from one image - select_inds = np.random.choice(self.H*self.W, size=(N_rand,), replace=False) - - rays_o = self.rays_o[select_inds, :] # [N_rand, 3] - rays_d = self.rays_d[select_inds, :] # [N_rand, 3] - depth = self.depth[select_inds] # [N_rand, ] - - if self.img is not None: - rgb = self.img[select_inds, :] # [N_rand, 3] - else: - rgb = None - - if self.mask is not None: - mask = self.mask[select_inds] - else: - mask = None - - if self.min_depth is not None: - min_depth = self.min_depth[select_inds] - else: - min_depth = 1e-4 * np.ones_like(rays_d[..., 0]) - - ret = OrderedDict([ - ('ray_o', rays_o), - ('ray_d', rays_d), - ('depth', depth), - ('rgb', rgb), - ('mask', mask), - ('min_depth', min_depth) - ]) - # return torch tensors - for k in ret: - if ret[k] is not None: - ret[k] = torch.from_numpy(ret[k]) - - return ret - - # def random_sample_patches(self, N_patch, r_patch=16, center_crop=False): - # ''' - # :param N_patch: number of patches to be sampled - # :param r_patch: patch size will be (2*r_patch+1)*(2*r_patch+1) - # :return: - # ''' - # # even size patch - # # offsets to center pixels - # u, v = np.meshgrid(np.arange(-r_patch, r_patch), - # np.arange(-r_patch, r_patch)) - # u = u.reshape(-1) - # v = v.reshape(-1) - # offsets = v * self.W + u - - # # center pixel coordinates - # u_min = r_patch - # u_max = self.W - r_patch - # v_min = r_patch - # v_max = self.H - r_patch - # if center_crop: - # u_min = self.W // 4 + r_patch - # u_max = self.W - self.W // 4 - r_patch - # v_min = self.H // 4 + r_patch - # v_max = self.H - self.H // 4 - r_patch - - # u, v = np.meshgrid(np.arange(u_min, u_max, r_patch), - # np.arange(v_min, v_max, r_patch)) - # u = u.reshape(-1) - # v = v.reshape(-1) - - # select_inds = np.random.choice(u.shape[0], size=(N_patch,), replace=False) - # # Convert back to original image - # select_inds = v[select_inds] * self.W + u[select_inds] - - # # pick patches - # select_inds = np.stack([select_inds + shift for shift in offsets], axis=1) - # select_inds = select_inds.reshape(-1) - - # rays_o = self.rays_o[select_inds, :] # [N_rand, 3] - # rays_d = self.rays_d[select_inds, :] # [N_rand, 3] - # depth = self.depth[select_inds] # [N_rand, ] - - # if self.img is not None: - # rgb = self.img[select_inds, :] # [N_rand, 3] - - # # ### debug - # # import imageio - # # imgs = rgb.reshape((N_patch, r_patch*2, r_patch*2, -1)) - # # for kk in range(imgs.shape[0]): - # # imageio.imwrite('./debug_{}.png'.format(kk), imgs[kk]) - # # ### - # else: - # rgb = None - - # ret = OrderedDict([ - # ('ray_o', rays_o), - # ('ray_d', rays_d), - # ('depth', depth), - # ('rgb', rgb) - # ]) - - # # return torch tensors - # for k in ret: - # ret[k] = torch.from_numpy(ret[k]) - - # return ret