clean code
This commit is contained in:
parent
5fbe15ff24
commit
e60d030d34
5 changed files with 138 additions and 1029 deletions
138
.gitignore
vendored
Normal file
138
.gitignore
vendored
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
Binary file not shown.
|
@ -1,181 +0,0 @@
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import imageio
|
|
||||||
from collections import OrderedDict
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__package__)
|
|
||||||
|
|
||||||
########################################################################################################################
|
|
||||||
# camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention)
|
|
||||||
# poses is camera-to-world
|
|
||||||
########################################################################################################################
|
|
||||||
|
|
||||||
def load_data(basedir, scene, testskip=8):
|
|
||||||
def parse_txt(filename):
|
|
||||||
assert os.path.isfile(filename)
|
|
||||||
nums = open(filename).read().split()
|
|
||||||
return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)
|
|
||||||
|
|
||||||
def dir2poses(posedir):
|
|
||||||
poses = np.stack(
|
|
||||||
[parse_txt(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
|
|
||||||
poses = poses.astype(np.float32)
|
|
||||||
return poses
|
|
||||||
|
|
||||||
def dir2intrinsics(intrinsicdir):
|
|
||||||
intrinsics = np.stack(
|
|
||||||
[parse_txt(os.path.join(intrinsicdir, f)) for f in sorted(os.listdir(intrinsicdir)) if f.endswith('txt')], 0)
|
|
||||||
intrinsics = intrinsics.astype(np.float32)
|
|
||||||
return intrinsics
|
|
||||||
|
|
||||||
intrinsics = dir2intrinsics('{}/{}/train/intrinsics'.format(basedir, scene))
|
|
||||||
testintrinsics = dir2poses('{}/{}/test/intrinsics'.format(basedir, scene))
|
|
||||||
testintrinsics = testintrinsics[::testskip]
|
|
||||||
valintrinsics = dir2poses('{}/{}/validation/intrinsics'.format(basedir, scene))
|
|
||||||
valintrinsics = valintrinsics[::testskip]
|
|
||||||
|
|
||||||
print(intrinsics.shape, testintrinsics.shape, valintrinsics.shape)
|
|
||||||
|
|
||||||
poses = dir2poses('{}/{}/train/pose'.format(basedir, scene))
|
|
||||||
testposes = dir2poses('{}/{}/test/pose'.format(basedir, scene))
|
|
||||||
testposes = testposes[::testskip]
|
|
||||||
valposes = dir2poses('{}/{}/validation/pose'.format(basedir, scene))
|
|
||||||
valposes = valposes[::testskip]
|
|
||||||
|
|
||||||
print(poses.shape, testposes.shape, valposes.shape)
|
|
||||||
|
|
||||||
imgd = '{}/{}/train/rgb'.format(basedir, scene)
|
|
||||||
imgfiles = ['{}/{}'.format(imgd, f)
|
|
||||||
for f in sorted(os.listdir(imgd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
imgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in imgfiles]
|
|
||||||
|
|
||||||
maskd = '{}/{}/train/mask'.format(basedir, scene)
|
|
||||||
if os.path.isdir(maskd):
|
|
||||||
logger.info('Loading mask from: {}'.format(maskd))
|
|
||||||
maskfiles = ['{}/{}'.format(maskd, f)
|
|
||||||
for f in sorted(os.listdir(maskd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
masks = [imageio.imread(f).astype(np.float32) / 255. for f in maskfiles]
|
|
||||||
else:
|
|
||||||
masks = [None for im in imgs]
|
|
||||||
|
|
||||||
# load min_depth map
|
|
||||||
min_depthd = '{}/{}/train/min_depth'.format(basedir, scene)
|
|
||||||
if os.path.isdir(min_depthd):
|
|
||||||
logger.info('Loading min_depth from: {}'.format(min_depthd))
|
|
||||||
max_depth = float(open('{}/{}/train/max_depth.txt'.format(basedir, scene)).readline().strip())
|
|
||||||
min_depthfiles = ['{}/{}'.format(min_depthd, f)
|
|
||||||
for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles]
|
|
||||||
else:
|
|
||||||
min_depths = [None for im in imgs]
|
|
||||||
|
|
||||||
testimgd = '{}/{}/test/rgb'.format(basedir, scene)
|
|
||||||
testimgfiles = ['{}/{}'.format(testimgd, f)
|
|
||||||
for f in sorted(os.listdir(testimgd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
testimgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in testimgfiles]
|
|
||||||
testimgfiles = testimgfiles[::testskip]
|
|
||||||
testimgs = testimgs[::testskip]
|
|
||||||
|
|
||||||
testmaskd = '{}/{}/test/mask'.format(basedir, scene)
|
|
||||||
if os.path.isdir(testmaskd):
|
|
||||||
logger.info('Loading mask from: {}'.format(testmaskd))
|
|
||||||
testmaskfiles = ['{}/{}'.format(testmaskd, f)
|
|
||||||
for f in sorted(os.listdir(testmaskd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
testmasks = [imageio.imread(f).astype(np.float32) / 255. for f in testmaskfiles]
|
|
||||||
else:
|
|
||||||
testmasks = [None for im in testimgs]
|
|
||||||
|
|
||||||
# load min_depth map
|
|
||||||
min_depthd = '{}/{}/test/min_depth'.format(basedir, scene)
|
|
||||||
if os.path.isdir(min_depthd):
|
|
||||||
logger.info('Loading min_depth from: {}'.format(min_depthd))
|
|
||||||
max_depth = float(open('{}/{}/test/max_depth.txt'.format(basedir, scene)).readline().strip())
|
|
||||||
min_depthfiles = ['{}/{}'.format(min_depthd, f)
|
|
||||||
for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
test_min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles]
|
|
||||||
else:
|
|
||||||
test_min_depths = [None for im in testimgs]
|
|
||||||
|
|
||||||
valimgd = '{}/{}/validation/rgb'.format(basedir, scene)
|
|
||||||
valimgfiles = ['{}/{}'.format(valimgd, f)
|
|
||||||
for f in sorted(os.listdir(valimgd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
valimgs = [imageio.imread(f).astype(np.float32)[..., :3] / 255. for f in valimgfiles]
|
|
||||||
valimgfiles = valimgfiles[::testskip]
|
|
||||||
valimgs = valimgs[::testskip]
|
|
||||||
|
|
||||||
valmaskd = '{}/{}/validation/mask'.format(basedir, scene)
|
|
||||||
if os.path.isdir(valmaskd):
|
|
||||||
logger.info('Loading mask from: {}'.format(valmaskd))
|
|
||||||
valmaskfiles = ['{}/{}'.format(valmaskd, f)
|
|
||||||
for f in sorted(os.listdir(valmaskd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
valmasks = [imageio.imread(f).astype(np.float32) / 255. for f in valmaskfiles]
|
|
||||||
else:
|
|
||||||
valmasks = [None for im in valimgs]
|
|
||||||
|
|
||||||
# load min_depth map
|
|
||||||
min_depthd = '{}/{}/validation/min_depth'.format(basedir, scene)
|
|
||||||
if os.path.isdir(min_depthd):
|
|
||||||
logger.info('Loading min_depth from: {}'.format(min_depthd))
|
|
||||||
max_depth = float(open('{}/{}/validation/max_depth.txt'.format(basedir, scene)).readline().strip())
|
|
||||||
min_depthfiles = ['{}/{}'.format(min_depthd, f)
|
|
||||||
for f in sorted(os.listdir(min_depthd)) if f.endswith('png') or f.endswith('jpg')]
|
|
||||||
val_min_depths = [imageio.imread(f).astype(np.float32) / 255. * max_depth + 1e-4 for f in min_depthfiles]
|
|
||||||
else:
|
|
||||||
val_min_depths = [None for im in valimgs]
|
|
||||||
|
|
||||||
# data format for training/testing
|
|
||||||
print(len(imgs), len(testimgs), len(valimgs))
|
|
||||||
all_imgs = imgs + valimgs + testimgs
|
|
||||||
all_masks = masks + valmasks + testmasks
|
|
||||||
all_min_depths = min_depths + val_min_depths + test_min_depths
|
|
||||||
all_paths = imgfiles + valimgfiles + testimgfiles
|
|
||||||
|
|
||||||
counts = [0] + [len(x) for x in [imgs, valimgs, testimgs]]
|
|
||||||
counts = np.cumsum(counts)
|
|
||||||
i_split = [list(np.arange(counts[i], counts[i+1])) for i in range(3)]
|
|
||||||
|
|
||||||
intrinsics = np.concatenate([intrinsics, valintrinsics, testintrinsics], 0)
|
|
||||||
poses = np.concatenate([poses, valposes, testposes], 0)
|
|
||||||
img_sizes = np.stack([np.array(x.shape[:2]) for x in all_imgs], axis=0) # [H, W]
|
|
||||||
cnt = len(all_imgs)
|
|
||||||
all_cams = np.concatenate((img_sizes.astype(dtype=np.float32), intrinsics.reshape((cnt, -1)), poses.reshape((cnt, -1))), axis=1)
|
|
||||||
|
|
||||||
if os.path.isdir('{}/{}/camera_path/intrinsics'.format(basedir, scene)):
|
|
||||||
camera_path_intrinsics = dir2poses('{}/{}/camera_path/intrinsics'.format(basedir, scene))
|
|
||||||
camera_path_poses = dir2poses('{}/{}/camera_path/pose'.format(basedir, scene))
|
|
||||||
# assume centered principal points
|
|
||||||
# img_sizes = np.stack((camera_path_intrinsics[:, 1, 2]*2, camera_path_intrinsics[:, 0, 2]*2), axis=1) # [H, W]
|
|
||||||
# img_sizes = np.int32(img_sizes)
|
|
||||||
|
|
||||||
H = all_cams[0, 0]
|
|
||||||
W = all_cams[0, 1]
|
|
||||||
img_sizes = np.stack((np.ones_like(camera_path_intrinsics[:, 1, 2])*H, np.ones_like(camera_path_intrinsics[:, 0, 2])*W), axis=1) # [H, W]
|
|
||||||
|
|
||||||
cnt = len(camera_path_intrinsics)
|
|
||||||
render_cams = np.concatenate(
|
|
||||||
(img_sizes.astype(dtype=np.float32), camera_path_intrinsics.reshape((cnt, -1)), camera_path_poses.reshape((cnt, -1))),
|
|
||||||
axis=1)
|
|
||||||
else:
|
|
||||||
render_cams = None
|
|
||||||
|
|
||||||
print(all_cams.shape)
|
|
||||||
|
|
||||||
data = OrderedDict([('images', all_imgs),
|
|
||||||
('masks', all_masks),
|
|
||||||
('paths', all_paths),
|
|
||||||
('min_depths', all_min_depths),
|
|
||||||
('cameras', all_cams),
|
|
||||||
('i_train', i_split[0]),
|
|
||||||
('i_val', i_split[1]),
|
|
||||||
('i_test', i_split[2]),
|
|
||||||
('render_cams', render_cams)])
|
|
||||||
|
|
||||||
logger.info('Data statistics:')
|
|
||||||
logger.info('\t # of training views: {}'.format(len(data['i_train'])))
|
|
||||||
logger.info('\t # of validation views: {}'.format(len(data['i_val'])))
|
|
||||||
logger.info('\t # of test views: {}'.format(len(data['i_test'])))
|
|
||||||
if data['render_cams'] is not None:
|
|
||||||
logger.info('\t # of render cameras: {}'.format(len(data['render_cams'])))
|
|
||||||
|
|
||||||
return data
|
|
|
@ -1,617 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim
|
|
||||||
import torch.distributed
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
import torch.multiprocessing
|
|
||||||
|
|
||||||
import os
|
|
||||||
from collections import OrderedDict
|
|
||||||
from ddp_model import NerfNet
|
|
||||||
import time
|
|
||||||
# from data_loader import load_data
|
|
||||||
# from nerf_sample_ray import RaySamplerSingleImage
|
|
||||||
|
|
||||||
from data_loader_split import load_data_split
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from utils import img2mse, mse2psnr, img_HWC2CHW, colorize, TINY_NUMBER
|
|
||||||
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__package__)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logger():
|
|
||||||
# create logger
|
|
||||||
logger = logging.getLogger(__package__)
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
# create console handler and set level to debug
|
|
||||||
ch = logging.StreamHandler()
|
|
||||||
ch.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
# create formatter
|
|
||||||
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
|
|
||||||
|
|
||||||
# add formatter to ch
|
|
||||||
ch.setFormatter(formatter)
|
|
||||||
|
|
||||||
# add ch to logger
|
|
||||||
logger.addHandler(ch)
|
|
||||||
|
|
||||||
|
|
||||||
def intersect_sphere(ray_o, ray_d):
|
|
||||||
'''
|
|
||||||
ray_o, ray_d: [..., 3]
|
|
||||||
compute the depth of the intersection point between this ray and unit sphere
|
|
||||||
'''
|
|
||||||
# note: d1 becomes negative if this mid point is behind camera
|
|
||||||
d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
|
|
||||||
p = ray_o + d1.unsqueeze(-1) * ray_d
|
|
||||||
# consider the case where the ray does not intersect the sphere
|
|
||||||
ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
|
|
||||||
d2 = torch.sqrt(1. - torch.sum(p * p, dim=-1)) * ray_d_cos
|
|
||||||
|
|
||||||
return d1 + d2
|
|
||||||
|
|
||||||
|
|
||||||
def perturb_samples(z_vals):
|
|
||||||
# get intervals between samples
|
|
||||||
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
|
||||||
upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
|
|
||||||
lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
|
|
||||||
# uniform samples in those intervals
|
|
||||||
t_rand = torch.rand_like(z_vals)
|
|
||||||
z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]
|
|
||||||
|
|
||||||
return z_vals
|
|
||||||
|
|
||||||
|
|
||||||
def sample_pdf(bins, weights, N_samples, det=False):
|
|
||||||
'''
|
|
||||||
:param bins: tensor of shape [..., M+1], M is the number of bins
|
|
||||||
:param weights: tensor of shape [..., M]
|
|
||||||
:param N_samples: number of samples along each ray
|
|
||||||
:param det: if True, will perform deterministic sampling
|
|
||||||
:return: [..., N_samples]
|
|
||||||
'''
|
|
||||||
# Get pdf
|
|
||||||
weights = weights + TINY_NUMBER # prevent nans
|
|
||||||
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [..., M]
|
|
||||||
cdf = torch.cumsum(pdf, dim=-1) # [..., M]
|
|
||||||
cdf = torch.cat([torch.zeros_like(cdf[..., 0:1]), cdf], dim=-1) # [..., M+1]
|
|
||||||
|
|
||||||
# Take uniform samples
|
|
||||||
dots_sh = list(weights.shape[:-1])
|
|
||||||
M = weights.shape[-1]
|
|
||||||
|
|
||||||
min_cdf = 0.00
|
|
||||||
max_cdf = 1.00 # prevent outlier samples
|
|
||||||
|
|
||||||
if det:
|
|
||||||
u = torch.linspace(min_cdf, max_cdf, N_samples, device=bins.device)
|
|
||||||
u = u.view([1]*len(dots_sh) + [N_samples]).expand(dots_sh + [N_samples,]) # [..., N_samples]
|
|
||||||
else:
|
|
||||||
sh = dots_sh + [N_samples]
|
|
||||||
u = torch.rand(*sh, device=bins.device) * (max_cdf - min_cdf) + min_cdf # [..., N_samples]
|
|
||||||
|
|
||||||
# Invert CDF
|
|
||||||
# [..., N_samples, 1] >= [..., 1, M] ----> [..., N_samples, M] ----> [..., N_samples,]
|
|
||||||
above_inds = torch.sum(u.unsqueeze(-1) >= cdf[..., :M].unsqueeze(-2), dim=-1).long()
|
|
||||||
|
|
||||||
# random sample inside each bin
|
|
||||||
below_inds = torch.clamp(above_inds-1, min=0)
|
|
||||||
inds_g = torch.stack((below_inds, above_inds), dim=-1) # [..., N_samples, 2]
|
|
||||||
|
|
||||||
cdf = cdf.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1]
|
|
||||||
cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [..., N_samples, 2]
|
|
||||||
|
|
||||||
bins = bins.unsqueeze(-2).expand(dots_sh + [N_samples, M+1]) # [..., N_samples, M+1]
|
|
||||||
bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [..., N_samples, 2]
|
|
||||||
|
|
||||||
# fix numeric issue
|
|
||||||
denom = cdf_g[..., 1] - cdf_g[..., 0] # [..., N_samples]
|
|
||||||
denom = torch.where(denom<TINY_NUMBER, torch.ones_like(denom), denom)
|
|
||||||
t = (u - cdf_g[..., 0]) / denom
|
|
||||||
|
|
||||||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_NUMBER)
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
|
|
||||||
##### parallel rendering of a single image
|
|
||||||
ray_batch = ray_sampler.get_all()
|
|
||||||
# split into ranks; make sure different processes don't overlap
|
|
||||||
rank_split_sizes = [ray_batch['ray_d'].shape[0] // world_size, ] * world_size
|
|
||||||
rank_split_sizes[-1] = ray_batch['ray_d'].shape[0] - sum(rank_split_sizes[:-1])
|
|
||||||
for key in ray_batch:
|
|
||||||
if torch.is_tensor(ray_batch[key]):
|
|
||||||
ray_batch[key] = torch.split(ray_batch[key], rank_split_sizes)[rank].to(rank)
|
|
||||||
|
|
||||||
# split into chunks and render inside each process
|
|
||||||
ray_batch_split = OrderedDict()
|
|
||||||
for key in ray_batch:
|
|
||||||
if torch.is_tensor(ray_batch[key]):
|
|
||||||
ray_batch_split[key] = torch.split(ray_batch[key], chunk_size)
|
|
||||||
|
|
||||||
# forward and backward
|
|
||||||
ret_merge_chunk = [OrderedDict() for _ in range(models['cascade_level'])]
|
|
||||||
for s in range(len(ray_batch_split['ray_d'])):
|
|
||||||
ray_o = ray_batch_split['ray_o'][s]
|
|
||||||
ray_d = ray_batch_split['ray_d'][s]
|
|
||||||
min_depth = ray_batch_split['min_depth'][s]
|
|
||||||
|
|
||||||
dots_sh = list(ray_d.shape[:-1])
|
|
||||||
for m in range(models['cascade_level']):
|
|
||||||
net = models['net_{}'.format(m)]
|
|
||||||
# sample depths
|
|
||||||
N_samples = models['cascade_samples'][m]
|
|
||||||
if m == 0:
|
|
||||||
# foreground depth
|
|
||||||
fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]
|
|
||||||
# fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
|
|
||||||
fg_near_depth = min_depth # [..., 3]
|
|
||||||
step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
|
|
||||||
fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]
|
|
||||||
|
|
||||||
# background depth
|
|
||||||
bg_depth = torch.linspace(0., 1., N_samples).view(
|
|
||||||
[1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
|
|
||||||
|
|
||||||
# delete unused memory
|
|
||||||
del fg_near_depth
|
|
||||||
del step
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
|
||||||
# sample pdf and concat with earlier samples
|
|
||||||
fg_weights = ret['fg_weights'].clone().detach()
|
|
||||||
fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1]
|
|
||||||
fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
|
|
||||||
fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
|
|
||||||
N_samples=N_samples, det=True) # [..., N_samples]
|
|
||||||
fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
|
|
||||||
|
|
||||||
# sample pdf and concat with earlier samples
|
|
||||||
bg_weights = ret['bg_weights'].clone().detach()
|
|
||||||
bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
|
|
||||||
bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
|
|
||||||
bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
|
|
||||||
N_samples=N_samples, det=True) # [..., N_samples]
|
|
||||||
bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
|
|
||||||
|
|
||||||
# delete unused memory
|
|
||||||
del fg_weights
|
|
||||||
del fg_depth_mid
|
|
||||||
del fg_depth_samples
|
|
||||||
del bg_weights
|
|
||||||
del bg_depth_mid
|
|
||||||
del bg_depth_samples
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
ret = net(ray_o, ray_d, fg_far_depth, fg_depth, bg_depth)
|
|
||||||
|
|
||||||
for key in ret:
|
|
||||||
if key not in ['fg_weights', 'bg_weights']:
|
|
||||||
if torch.is_tensor(ret[key]):
|
|
||||||
if key not in ret_merge_chunk[m]:
|
|
||||||
ret_merge_chunk[m][key] = [ret[key].cpu(), ]
|
|
||||||
else:
|
|
||||||
ret_merge_chunk[m][key].append(ret[key].cpu())
|
|
||||||
|
|
||||||
ret[key] = None
|
|
||||||
|
|
||||||
# clean unused memory
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# merge results from different chunks
|
|
||||||
for m in range(len(ret_merge_chunk)):
|
|
||||||
for key in ret_merge_chunk[m]:
|
|
||||||
ret_merge_chunk[m][key] = torch.cat(ret_merge_chunk[m][key], dim=0)
|
|
||||||
|
|
||||||
# merge results from different processes
|
|
||||||
if rank == 0:
|
|
||||||
ret_merge_rank = [OrderedDict() for _ in range(len(ret_merge_chunk))]
|
|
||||||
for m in range(len(ret_merge_chunk)):
|
|
||||||
for key in ret_merge_chunk[m]:
|
|
||||||
# generate tensors to store results from other processes
|
|
||||||
sh = list(ret_merge_chunk[m][key].shape[1:])
|
|
||||||
ret_merge_rank[m][key] = [torch.zeros(*[size,]+sh, dtype=torch.float32) for size in rank_split_sizes]
|
|
||||||
torch.distributed.gather(ret_merge_chunk[m][key], ret_merge_rank[m][key])
|
|
||||||
ret_merge_rank[m][key] = torch.cat(ret_merge_rank[m][key], dim=0).reshape(
|
|
||||||
(ray_sampler.H, ray_sampler.W, -1)).squeeze()
|
|
||||||
# print(m, key, ret_merge_rank[m][key].shape)
|
|
||||||
else: # send results to main process
|
|
||||||
for m in range(len(ret_merge_chunk)):
|
|
||||||
for key in ret_merge_chunk[m]:
|
|
||||||
torch.distributed.gather(ret_merge_chunk[m][key])
|
|
||||||
|
|
||||||
|
|
||||||
# only rank 0 program returns
|
|
||||||
if rank == 0:
|
|
||||||
return ret_merge_rank
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def log_view_to_tb(writer, global_step, log_data, gt_img, mask, prefix=''):
|
|
||||||
rgb_im = img_HWC2CHW(torch.from_numpy(gt_img))
|
|
||||||
writer.add_image(prefix + 'rgb_gt', rgb_im, global_step)
|
|
||||||
|
|
||||||
for m in range(len(log_data)):
|
|
||||||
rgb_im = img_HWC2CHW(log_data[m]['rgb'])
|
|
||||||
rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1
|
|
||||||
writer.add_image(prefix + 'level_{}/rgb'.format(m), rgb_im, global_step)
|
|
||||||
|
|
||||||
rgb_im = img_HWC2CHW(log_data[m]['fg_rgb'])
|
|
||||||
rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1
|
|
||||||
writer.add_image(prefix + 'level_{}/fg_rgb'.format(m), rgb_im, global_step)
|
|
||||||
depth = log_data[m]['fg_depth']
|
|
||||||
depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True,
|
|
||||||
mask=mask))
|
|
||||||
writer.add_image(prefix + 'level_{}/fg_depth'.format(m), depth_im, global_step)
|
|
||||||
|
|
||||||
rgb_im = img_HWC2CHW(log_data[m]['bg_rgb'])
|
|
||||||
rgb_im = torch.clamp(rgb_im, min=0., max=1.) # just in case diffuse+specular>1
|
|
||||||
writer.add_image(prefix + 'level_{}/bg_rgb'.format(m), rgb_im, global_step)
|
|
||||||
depth = log_data[m]['bg_depth']
|
|
||||||
depth_im = img_HWC2CHW(colorize(depth, cmap_name='jet', append_cbar=True,
|
|
||||||
mask=mask))
|
|
||||||
writer.add_image(prefix + 'level_{}/bg_depth'.format(m), depth_im, global_step)
|
|
||||||
bg_lambda = log_data[m]['bg_lambda']
|
|
||||||
bg_lambda_im = img_HWC2CHW(colorize(bg_lambda, cmap_name='hot', append_cbar=True,
|
|
||||||
mask=mask))
|
|
||||||
writer.add_image(prefix + 'level_{}/bg_lambda'.format(m), bg_lambda_im, global_step)
|
|
||||||
|
|
||||||
|
|
||||||
def setup(rank, world_size):
|
|
||||||
os.environ['MASTER_ADDR'] = 'localhost'
|
|
||||||
os.environ['MASTER_PORT'] = '12355'
|
|
||||||
# initialize the process group
|
|
||||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
torch.distributed.destroy_process_group()
|
|
||||||
|
|
||||||
|
|
||||||
def ddp_train_nerf(rank, args):
|
|
||||||
###### set up multi-processing
|
|
||||||
setup(rank, args.world_size)
|
|
||||||
###### set up logger
|
|
||||||
logger = logging.getLogger(__package__)
|
|
||||||
setup_logger()
|
|
||||||
|
|
||||||
###### Create log dir and copy the config file
|
|
||||||
if rank == 0:
|
|
||||||
os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
|
|
||||||
f = os.path.join(args.basedir, args.expname, 'args.txt')
|
|
||||||
with open(f, 'w') as file:
|
|
||||||
for arg in sorted(vars(args)):
|
|
||||||
attr = getattr(args, arg)
|
|
||||||
file.write('{} = {}\n'.format(arg, attr))
|
|
||||||
if args.config is not None:
|
|
||||||
f = os.path.join(args.basedir, args.expname, 'config.txt')
|
|
||||||
with open(f, 'w') as file:
|
|
||||||
file.write(open(args.config, 'r').read())
|
|
||||||
torch.distributed.barrier()
|
|
||||||
|
|
||||||
###### load data and create ray samplers; each process should do this
|
|
||||||
# data = load_data(args.datadir, args.scene, args.testskip)
|
|
||||||
# ray_samplers = []
|
|
||||||
# for i in data['i_train']:
|
|
||||||
# ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i],
|
|
||||||
# img=data['images'][i],
|
|
||||||
# img_path=data['paths'][i],
|
|
||||||
# mask=data['masks'][i],
|
|
||||||
# min_depth=data['min_depths'][i]))
|
|
||||||
#
|
|
||||||
# val_ray_samplers = []
|
|
||||||
# for i in data['i_val']:
|
|
||||||
# val_ray_samplers.append(RaySamplerSingleImage(cam_params=data['cameras'][i],
|
|
||||||
# img=data['images'][i],
|
|
||||||
# img_path=data['paths'][i],
|
|
||||||
# mask=data['masks'][i],
|
|
||||||
# min_depth=data['min_depths'][i]))
|
|
||||||
# # free memory
|
|
||||||
# del data
|
|
||||||
|
|
||||||
ray_samplers = load_data_split(args.datadir, args.scene, split='train')
|
|
||||||
val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation')
|
|
||||||
|
|
||||||
###### create network and wrap in ddp; each process should do this
|
|
||||||
# fix random seed just to make sure the network is initialized with same weights at different processes
|
|
||||||
torch.manual_seed(777)
|
|
||||||
# very important!!! otherwise it might introduce extra memory in rank=0 gpu
|
|
||||||
torch.cuda.set_device(rank)
|
|
||||||
|
|
||||||
models = OrderedDict()
|
|
||||||
models['cascade_level'] = args.cascade_level
|
|
||||||
models['cascade_samples'] = [int(x.strip()) for x in args.cascade_samples.split(',')]
|
|
||||||
for m in range(models['cascade_level']):
|
|
||||||
net = NerfNet(args).to(rank)
|
|
||||||
net = DDP(net, device_ids=[rank], output_device=rank)
|
|
||||||
optim = torch.optim.Adam(net.parameters(), lr=args.lrate)
|
|
||||||
models['net_{}'.format(m)] = net
|
|
||||||
models['optim_{}'.format(m)] = optim
|
|
||||||
|
|
||||||
start = -1
|
|
||||||
|
|
||||||
###### load pretrained weights; each process should do this
|
|
||||||
if (args.ckpt_path is not None) and (os.path.isfile(args.ckpt_path)):
|
|
||||||
ckpts = [args.ckpt_path]
|
|
||||||
else:
|
|
||||||
ckpts = [os.path.join(args.basedir, args.expname, f)
|
|
||||||
for f in sorted(os.listdir(os.path.join(args.basedir, args.expname))) if f.endswith('.pth')]
|
|
||||||
def path2iter(path):
|
|
||||||
tmp = os.path.basename(path)[:-4]
|
|
||||||
idx = tmp.rfind('_')
|
|
||||||
return int(tmp[idx + 1:])
|
|
||||||
ckpts = sorted(ckpts, key=path2iter)
|
|
||||||
logger.info('Found ckpts: {}'.format(ckpts))
|
|
||||||
if len(ckpts) > 0 and not args.no_reload:
|
|
||||||
fpath = ckpts[-1]
|
|
||||||
logger.info('Reloading from: {}'.format(fpath))
|
|
||||||
start = path2iter(fpath)
|
|
||||||
# configure map_location properly for different processes
|
|
||||||
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
|
|
||||||
to_load = torch.load(fpath, map_location=map_location)
|
|
||||||
for m in range(models['cascade_level']):
|
|
||||||
for name in ['net_{}'.format(m), 'optim_{}'.format(m)]:
|
|
||||||
models[name].load_state_dict(to_load[name])
|
|
||||||
models[name].load_state_dict(to_load[name])
|
|
||||||
|
|
||||||
##### important!!!
|
|
||||||
# make sure different processes sample different rays
|
|
||||||
np.random.seed((rank + 1) * 777)
|
|
||||||
# make sure different processes have different perturbations in depth samples
|
|
||||||
torch.manual_seed((rank + 1) * 777)
|
|
||||||
|
|
||||||
##### only main process should do the logging
|
|
||||||
if rank == 0:
|
|
||||||
writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))
|
|
||||||
|
|
||||||
# start training
|
|
||||||
what_val_to_log = 0 # helper variable for parallel rendering of a image
|
|
||||||
what_train_to_log = 0
|
|
||||||
for global_step in range(start+1, start+1+args.N_iters):
|
|
||||||
time0 = time.time()
|
|
||||||
scalars_to_log = OrderedDict()
|
|
||||||
### Start of core optimization loop
|
|
||||||
scalars_to_log['resolution'] = ray_samplers[0].resolution_level
|
|
||||||
# randomly sample rays and move to device
|
|
||||||
i = np.random.randint(low=0, high=len(ray_samplers))
|
|
||||||
ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
|
|
||||||
for key in ray_batch:
|
|
||||||
if torch.is_tensor(ray_batch[key]):
|
|
||||||
ray_batch[key] = ray_batch[key].to(rank)
|
|
||||||
|
|
||||||
# forward and backward
|
|
||||||
dots_sh = list(ray_batch['ray_d'].shape[:-1]) # number of rays
|
|
||||||
all_rets = [] # results on different cascade levels
|
|
||||||
for m in range(models['cascade_level']):
|
|
||||||
optim = models['optim_{}'.format(m)]
|
|
||||||
net = models['net_{}'.format(m)]
|
|
||||||
|
|
||||||
# sample depths
|
|
||||||
N_samples = models['cascade_samples'][m]
|
|
||||||
if m == 0:
|
|
||||||
# foreground depth
|
|
||||||
fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d']) # [...,]
|
|
||||||
# fg_near_depth = 0.18 * torch.ones_like(fg_far_depth)
|
|
||||||
fg_near_depth = ray_batch['min_depth'] # [..., 3]
|
|
||||||
step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
|
|
||||||
fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]
|
|
||||||
fg_depth = perturb_samples(fg_depth) # random perturbation during training
|
|
||||||
|
|
||||||
# background depth
|
|
||||||
bg_depth = torch.linspace(0., 1., N_samples).view(
|
|
||||||
[1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(rank)
|
|
||||||
bg_depth = perturb_samples(bg_depth) # random perturbation during training
|
|
||||||
else:
|
|
||||||
# sample pdf and concat with earlier samples
|
|
||||||
fg_weights = ret['fg_weights'].clone().detach()
|
|
||||||
fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1]) # [..., N_samples-1]
|
|
||||||
fg_weights = fg_weights[..., 1:-1] # [..., N_samples-2]
|
|
||||||
fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
|
|
||||||
N_samples=N_samples, det=False) # [..., N_samples]
|
|
||||||
fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))
|
|
||||||
|
|
||||||
# sample pdf and concat with earlier samples
|
|
||||||
bg_weights = ret['bg_weights'].clone().detach()
|
|
||||||
bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
|
|
||||||
bg_weights = bg_weights[..., 1:-1] # [..., N_samples-2]
|
|
||||||
bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
|
|
||||||
N_samples=N_samples, det=False) # [..., N_samples]
|
|
||||||
bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))
|
|
||||||
|
|
||||||
optim.zero_grad()
|
|
||||||
ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth)
|
|
||||||
all_rets.append(ret)
|
|
||||||
|
|
||||||
rgb_gt = ray_batch['rgb'].to(rank)
|
|
||||||
loss = img2mse(ret['rgb'], rgb_gt)
|
|
||||||
scalars_to_log['level_{}/loss'.format(m)] = loss.item()
|
|
||||||
scalars_to_log['level_{}/pnsr'.format(m)] = mse2psnr(loss.item())
|
|
||||||
# regularize sigma with photo-consistency
|
|
||||||
loss = loss + img2mse(ret['diffuse_rgb'], rgb_gt)
|
|
||||||
loss.backward()
|
|
||||||
optim.step()
|
|
||||||
|
|
||||||
# # clean unused memory
|
|
||||||
# torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
### end of core optimization loop
|
|
||||||
dt = time.time() - time0
|
|
||||||
scalars_to_log['iter_time'] = dt
|
|
||||||
|
|
||||||
### only main process should do the logging
|
|
||||||
if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
|
|
||||||
logstr = '{} step: {} '.format(args.expname, global_step)
|
|
||||||
for k in scalars_to_log:
|
|
||||||
logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
|
|
||||||
writer.add_scalar(k, scalars_to_log[k], global_step)
|
|
||||||
logger.info(logstr)
|
|
||||||
|
|
||||||
### each process should do this; but only main process merges the results
|
|
||||||
if global_step % args.i_img == 0 or global_step == start+1:
|
|
||||||
#### critical: make sure each process is working on the same random image
|
|
||||||
time0 = time.time()
|
|
||||||
idx = what_val_to_log % len(val_ray_samplers)
|
|
||||||
log_data = render_single_image(rank, args.world_size, models, val_ray_samplers[idx], args.chunk_size)
|
|
||||||
what_val_to_log += 1
|
|
||||||
dt = time.time() - time0
|
|
||||||
if rank == 0: # only main process should do this
|
|
||||||
logger.info('Logged a random validation view in {} seconds'.format(dt))
|
|
||||||
log_view_to_tb(writer, global_step, log_data, gt_img=val_ray_samplers[idx].img_orig, mask=None, prefix='val/')
|
|
||||||
|
|
||||||
time0 = time.time()
|
|
||||||
idx = what_train_to_log % len(ray_samplers)
|
|
||||||
log_data = render_single_image(rank, args.world_size, models, ray_samplers[idx], args.chunk_size)
|
|
||||||
what_train_to_log += 1
|
|
||||||
dt = time.time() - time0
|
|
||||||
if rank == 0: # only main process should do this
|
|
||||||
logger.info('Logged a random training view in {} seconds'.format(dt))
|
|
||||||
log_view_to_tb(writer, global_step, log_data, gt_img=ray_samplers[idx].img_orig, mask=None, prefix='train/')
|
|
||||||
|
|
||||||
log_data = None
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
|
|
||||||
# saving checkpoints and logging
|
|
||||||
fpath = os.path.join(args.basedir, args.expname, 'model_{:06d}.pth'.format(global_step))
|
|
||||||
to_save = OrderedDict()
|
|
||||||
for m in range(models['cascade_level']):
|
|
||||||
name = 'net_{}'.format(m)
|
|
||||||
to_save[name] = models[name].state_dict()
|
|
||||||
|
|
||||||
name = 'optim_{}'.format(m)
|
|
||||||
to_save[name] = models[name].state_dict()
|
|
||||||
torch.save(to_save, fpath)
|
|
||||||
|
|
||||||
# clean up for multi-processing
|
|
||||||
cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
def config_parser():
|
|
||||||
import configargparse
|
|
||||||
parser = configargparse.ArgumentParser()
|
|
||||||
parser.add_argument('--config', is_config_file=True, help='config file path')
|
|
||||||
parser.add_argument("--expname", type=str, help='experiment name')
|
|
||||||
parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
|
|
||||||
|
|
||||||
# dataset options
|
|
||||||
parser.add_argument("--datadir", type=str, default=None, help='input data directory')
|
|
||||||
parser.add_argument("--scene", type=str, default=None, help='scene name')
|
|
||||||
parser.add_argument("--testskip", type=int, default=8,
|
|
||||||
help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
|
|
||||||
|
|
||||||
# model size
|
|
||||||
parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
|
|
||||||
parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
|
|
||||||
parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
|
|
||||||
|
|
||||||
# checkpoints
|
|
||||||
parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
|
|
||||||
parser.add_argument("--ckpt_path", type=str, default=None,
|
|
||||||
help='specific weights npy file to reload for coarse network')
|
|
||||||
|
|
||||||
# batch size
|
|
||||||
parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
|
|
||||||
help='batch size (number of random rays per gradient step)')
|
|
||||||
parser.add_argument("--chunk_size", type=int, default=1024 * 8,
|
|
||||||
help='number of rays processed in parallel, decrease if running out of memory')
|
|
||||||
|
|
||||||
# iterations
|
|
||||||
parser.add_argument("--N_iters", type=int, default=250001,
|
|
||||||
help='number of iterations')
|
|
||||||
|
|
||||||
# cascade training
|
|
||||||
parser.add_argument("--cascade_level", type=int, default=2,
|
|
||||||
help='number of cascade levels')
|
|
||||||
parser.add_argument("--cascade_samples", type=str, default='64,64',
|
|
||||||
help='samples at each level')
|
|
||||||
parser.add_argument("--devices", type=str, default='0,1',
|
|
||||||
help='cuda device for each level')
|
|
||||||
parser.add_argument("--bg_devices", type=str, default='0,2',
|
|
||||||
help='cuda device for the background of each level')
|
|
||||||
|
|
||||||
parser.add_argument("--world_size", type=int, default='-1',
|
|
||||||
help='number of processes')
|
|
||||||
|
|
||||||
# mixed precison training
|
|
||||||
parser.add_argument("--opt_level", type=str, default='O1',
|
|
||||||
help='mixed precison training')
|
|
||||||
|
|
||||||
parser.add_argument("--near_depth", type=float, default=0.1,
|
|
||||||
help='near depth plane')
|
|
||||||
parser.add_argument("--far_depth", type=float, default=50.,
|
|
||||||
help='far depth plane')
|
|
||||||
|
|
||||||
# learning rate options
|
|
||||||
parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
|
|
||||||
parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
|
|
||||||
help='decay learning rate by a factor every specified number of steps')
|
|
||||||
parser.add_argument("--lrate_decay_steps", type=int, default=5000,
|
|
||||||
help='decay learning rate by a factor every specified number of steps')
|
|
||||||
|
|
||||||
# rendering options
|
|
||||||
parser.add_argument("--inv_uniform", action='store_true',
|
|
||||||
help='if True, will uniformly sample inverse depths')
|
|
||||||
parser.add_argument("--det", action='store_true', help='deterministic sampling for coarse and fine samples')
|
|
||||||
parser.add_argument("--max_freq_log2", type=int, default=10,
|
|
||||||
help='log2 of max freq for positional encoding (3D location)')
|
|
||||||
parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
|
|
||||||
help='log2 of max freq for positional encoding (2D direction)')
|
|
||||||
parser.add_argument("--N_iters_perturb", type=int, default=1000,
|
|
||||||
help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck')
|
|
||||||
parser.add_argument("--raw_noise_std", type=float, default=1.,
|
|
||||||
help='std dev of noise added to regularize sigma output, 1e0 recommended')
|
|
||||||
parser.add_argument("--white_bkgd", action='store_true',
|
|
||||||
help='apply the trick to avoid fitting to white background')
|
|
||||||
|
|
||||||
# no training; render only
|
|
||||||
parser.add_argument("--render_only", action='store_true',
|
|
||||||
help='do not optimize, reload weights and render out render_poses path')
|
|
||||||
parser.add_argument("--render_train", action='store_true', help='render the training set')
|
|
||||||
parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
|
|
||||||
|
|
||||||
# no training; extract mesh only
|
|
||||||
parser.add_argument("--mesh_only", action='store_true',
|
|
||||||
help='do not optimize, extract mesh from pretrained model')
|
|
||||||
parser.add_argument("--N_pts", type=int, default=256,
|
|
||||||
help='voxel resolution; N_pts * N_pts * N_pts')
|
|
||||||
parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50',
|
|
||||||
help='threshold(s) for mesh extraction; can use multiple thresholds')
|
|
||||||
|
|
||||||
# logging/saving options
|
|
||||||
parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
|
|
||||||
parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
|
|
||||||
parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
|
|
||||||
parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
|
|
||||||
parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def train():
|
|
||||||
parser = config_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
logger.info(parser.format_values())
|
|
||||||
|
|
||||||
if args.world_size == -1:
|
|
||||||
args.world_size = torch.cuda.device_count()
|
|
||||||
logger.info('Using # gpus: {}'.format(args.world_size))
|
|
||||||
torch.multiprocessing.spawn(ddp_train_nerf,
|
|
||||||
args=(args,),
|
|
||||||
nprocs=args.world_size,
|
|
||||||
join=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
setup_logger()
|
|
||||||
train()
|
|
||||||
|
|
||||||
|
|
|
@ -1,231 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from collections import OrderedDict
|
|
||||||
import torch
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################################
|
|
||||||
# ray batch sampling
|
|
||||||
########################################################################################################################
|
|
||||||
|
|
||||||
def parse_camera(params):
|
|
||||||
H, W = params[:2]
|
|
||||||
intrinsics = params[2:18].reshape((4, 4))
|
|
||||||
c2w = params[18:34].reshape((4, 4))
|
|
||||||
|
|
||||||
return int(W), int(H), intrinsics.astype(np.float32), c2w.astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rays_single_image(H, W, intrinsics, c2w):
|
|
||||||
'''
|
|
||||||
:param H: image height
|
|
||||||
:param W: image width
|
|
||||||
:param intrinsics: 4 by 4 intrinsic matrix
|
|
||||||
:param c2w: 4 by 4 camera to world extrinsic matrix
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
u, v = np.meshgrid(np.arange(W), np.arange(H))
|
|
||||||
|
|
||||||
u = u.reshape(-1).astype(dtype=np.float32) + 0.5 # add half pixel
|
|
||||||
v = v.reshape(-1).astype(dtype=np.float32) + 0.5
|
|
||||||
pixels = np.stack((u, v, np.ones_like(u)), axis=0) # (3, H*W)
|
|
||||||
|
|
||||||
rays_d = np.dot(np.linalg.inv(intrinsics[:3, :3]), pixels)
|
|
||||||
rays_d = np.dot(c2w[:3, :3], rays_d) # (3, H*W)
|
|
||||||
rays_d = rays_d.transpose((1, 0)) # (H*W, 3)
|
|
||||||
|
|
||||||
rays_o = c2w[:3, 3].reshape((1, 3))
|
|
||||||
rays_o = np.tile(rays_o, (rays_d.shape[0], 1)) # (H*W, 3)
|
|
||||||
|
|
||||||
depth = np.linalg.inv(c2w)[2, 3]
|
|
||||||
depth = depth * np.ones((rays_o.shape[0],), dtype=np.float32) # (H*W,)
|
|
||||||
|
|
||||||
return rays_o, rays_d, depth
|
|
||||||
|
|
||||||
|
|
||||||
class RaySamplerSingleImage(object):
|
|
||||||
def __init__(self, cam_params, img_path=None, img=None, resolution_level=1, mask=None, min_depth=None):
|
|
||||||
super().__init__()
|
|
||||||
self.W_orig, self.H_orig, self.intrinsics_orig, self.c2w_mat = parse_camera(cam_params)
|
|
||||||
|
|
||||||
self.img_path = img_path
|
|
||||||
self.img_orig = img
|
|
||||||
self.mask_orig = mask
|
|
||||||
self.min_depth_orig = min_depth
|
|
||||||
|
|
||||||
self.resolution_level = -1
|
|
||||||
self.set_resolution_level(resolution_level)
|
|
||||||
|
|
||||||
def set_resolution_level(self, resolution_level):
|
|
||||||
if resolution_level != self.resolution_level:
|
|
||||||
self.resolution_level = resolution_level
|
|
||||||
self.W = self.W_orig // resolution_level
|
|
||||||
self.H = self.H_orig // resolution_level
|
|
||||||
self.intrinsics = np.copy(self.intrinsics_orig)
|
|
||||||
self.intrinsics[:2, :3] /= resolution_level
|
|
||||||
if self.img_orig is not None:
|
|
||||||
self.img = cv2.resize(self.img_orig, (self.W, self.H), interpolation=cv2.INTER_AREA)
|
|
||||||
self.img = self.img.reshape((-1, 3))
|
|
||||||
else:
|
|
||||||
self.img = None
|
|
||||||
|
|
||||||
if self.mask_orig is not None:
|
|
||||||
self.mask = cv2.resize(self.mask_orig, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
|
|
||||||
self.mask = self.mask.reshape((-1))
|
|
||||||
else:
|
|
||||||
self.mask = None
|
|
||||||
|
|
||||||
if self.min_depth_orig is not None:
|
|
||||||
self.min_depth = cv2.resize(self.min_depth_orig, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
|
|
||||||
self.min_depth = self.min_depth.reshape((-1))
|
|
||||||
else:
|
|
||||||
self.min_depth = None
|
|
||||||
|
|
||||||
self.rays_o, self.rays_d, self.depth = get_rays_single_image(self.H, self.W,
|
|
||||||
self.intrinsics, self.c2w_mat)
|
|
||||||
|
|
||||||
def get_all(self):
|
|
||||||
if self.min_depth is not None:
|
|
||||||
min_depth = self.min_depth
|
|
||||||
else:
|
|
||||||
min_depth = 1e-4 * np.ones_like(self.rays_d[..., 0])
|
|
||||||
|
|
||||||
ret = OrderedDict([
|
|
||||||
('ray_o', self.rays_o),
|
|
||||||
('ray_d', self.rays_d),
|
|
||||||
('depth', self.depth),
|
|
||||||
('rgb', self.img),
|
|
||||||
('mask', self.mask),
|
|
||||||
('min_depth', min_depth)
|
|
||||||
])
|
|
||||||
# return torch tensors
|
|
||||||
for k in ret:
|
|
||||||
if ret[k] is not None:
|
|
||||||
ret[k] = torch.from_numpy(ret[k])
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def random_sample(self, N_rand, center_crop=False):
|
|
||||||
'''
|
|
||||||
:param N_rand: number of rays to be casted
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
if center_crop:
|
|
||||||
half_H = self.H // 2
|
|
||||||
half_W = self.W // 2
|
|
||||||
quad_H = half_H // 2
|
|
||||||
quad_W = half_W // 2
|
|
||||||
|
|
||||||
# pixel coordinates
|
|
||||||
u, v = np.meshgrid(np.arange(half_W-quad_W, half_W+quad_W),
|
|
||||||
np.arange(half_H-quad_H, half_H+quad_H))
|
|
||||||
u = u.reshape(-1)
|
|
||||||
v = v.reshape(-1)
|
|
||||||
|
|
||||||
select_inds = np.random.choice(u.shape[0], size=(N_rand,), replace=False)
|
|
||||||
|
|
||||||
# Convert back to original image
|
|
||||||
select_inds = v[select_inds] * self.W + u[select_inds]
|
|
||||||
else:
|
|
||||||
# Random from one image
|
|
||||||
select_inds = np.random.choice(self.H*self.W, size=(N_rand,), replace=False)
|
|
||||||
|
|
||||||
rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
|
|
||||||
rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
|
|
||||||
depth = self.depth[select_inds] # [N_rand, ]
|
|
||||||
|
|
||||||
if self.img is not None:
|
|
||||||
rgb = self.img[select_inds, :] # [N_rand, 3]
|
|
||||||
else:
|
|
||||||
rgb = None
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
mask = self.mask[select_inds]
|
|
||||||
else:
|
|
||||||
mask = None
|
|
||||||
|
|
||||||
if self.min_depth is not None:
|
|
||||||
min_depth = self.min_depth[select_inds]
|
|
||||||
else:
|
|
||||||
min_depth = 1e-4 * np.ones_like(rays_d[..., 0])
|
|
||||||
|
|
||||||
ret = OrderedDict([
|
|
||||||
('ray_o', rays_o),
|
|
||||||
('ray_d', rays_d),
|
|
||||||
('depth', depth),
|
|
||||||
('rgb', rgb),
|
|
||||||
('mask', mask),
|
|
||||||
('min_depth', min_depth)
|
|
||||||
])
|
|
||||||
# return torch tensors
|
|
||||||
for k in ret:
|
|
||||||
if ret[k] is not None:
|
|
||||||
ret[k] = torch.from_numpy(ret[k])
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
# def random_sample_patches(self, N_patch, r_patch=16, center_crop=False):
|
|
||||||
# '''
|
|
||||||
# :param N_patch: number of patches to be sampled
|
|
||||||
# :param r_patch: patch size will be (2*r_patch+1)*(2*r_patch+1)
|
|
||||||
# :return:
|
|
||||||
# '''
|
|
||||||
# # even size patch
|
|
||||||
# # offsets to center pixels
|
|
||||||
# u, v = np.meshgrid(np.arange(-r_patch, r_patch),
|
|
||||||
# np.arange(-r_patch, r_patch))
|
|
||||||
# u = u.reshape(-1)
|
|
||||||
# v = v.reshape(-1)
|
|
||||||
# offsets = v * self.W + u
|
|
||||||
|
|
||||||
# # center pixel coordinates
|
|
||||||
# u_min = r_patch
|
|
||||||
# u_max = self.W - r_patch
|
|
||||||
# v_min = r_patch
|
|
||||||
# v_max = self.H - r_patch
|
|
||||||
# if center_crop:
|
|
||||||
# u_min = self.W // 4 + r_patch
|
|
||||||
# u_max = self.W - self.W // 4 - r_patch
|
|
||||||
# v_min = self.H // 4 + r_patch
|
|
||||||
# v_max = self.H - self.H // 4 - r_patch
|
|
||||||
|
|
||||||
# u, v = np.meshgrid(np.arange(u_min, u_max, r_patch),
|
|
||||||
# np.arange(v_min, v_max, r_patch))
|
|
||||||
# u = u.reshape(-1)
|
|
||||||
# v = v.reshape(-1)
|
|
||||||
|
|
||||||
# select_inds = np.random.choice(u.shape[0], size=(N_patch,), replace=False)
|
|
||||||
# # Convert back to original image
|
|
||||||
# select_inds = v[select_inds] * self.W + u[select_inds]
|
|
||||||
|
|
||||||
# # pick patches
|
|
||||||
# select_inds = np.stack([select_inds + shift for shift in offsets], axis=1)
|
|
||||||
# select_inds = select_inds.reshape(-1)
|
|
||||||
|
|
||||||
# rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
|
|
||||||
# rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
|
|
||||||
# depth = self.depth[select_inds] # [N_rand, ]
|
|
||||||
|
|
||||||
# if self.img is not None:
|
|
||||||
# rgb = self.img[select_inds, :] # [N_rand, 3]
|
|
||||||
|
|
||||||
# # ### debug
|
|
||||||
# # import imageio
|
|
||||||
# # imgs = rgb.reshape((N_patch, r_patch*2, r_patch*2, -1))
|
|
||||||
# # for kk in range(imgs.shape[0]):
|
|
||||||
# # imageio.imwrite('./debug_{}.png'.format(kk), imgs[kk])
|
|
||||||
# # ###
|
|
||||||
# else:
|
|
||||||
# rgb = None
|
|
||||||
|
|
||||||
# ret = OrderedDict([
|
|
||||||
# ('ray_o', rays_o),
|
|
||||||
# ('ray_d', rays_d),
|
|
||||||
# ('depth', depth),
|
|
||||||
# ('rgb', rgb)
|
|
||||||
# ])
|
|
||||||
|
|
||||||
# # return torch tensors
|
|
||||||
# for k in ret:
|
|
||||||
# ret[k] = torch.from_numpy(ret[k])
|
|
||||||
|
|
||||||
# return ret
|
|
Loading…
Reference in a new issue