Add doc string
parent f61ca730eb
commit c3ccc0bdb3

1 changed file with 155 additions and 56 deletions

run_nerf.py (211 lines changed: +155, -56)
@@ -25,6 +25,8 @@ DEBUG = False


 def batchify(fn, chunk):
+    """Constructs a version of 'fn' that applies to smaller batches.
+    """
     if chunk is None:
         return fn
     def ret(inputs):
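Note: the body of ret() is not part of this hunk, so the following is only a minimal sketch of the chunked-evaluation pattern the new docstring describes; the name batchify_sketch and the concatenation step are assumptions, not the committed code.

    import torch

    def batchify_sketch(fn, chunk):
        # Wrap `fn` so it is applied to `inputs` in slices of at most `chunk`
        # rows, with the per-slice outputs concatenated back together.
        if chunk is None:
            return fn
        def ret(inputs):
            return torch.cat([fn(inputs[i:i+chunk])
                              for i in range(0, inputs.shape[0], chunk)], 0)
        return ret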
@@ -33,7 +35,8 @@ def batchify(fn, chunk):


 def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
-
+    """Prepares inputs and applies network 'fn'.
+    """
     inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
     embedded = embed_fn(inputs_flat)

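Note: only the first two statements of run_network appear in this hunk. A plausible end-to-end flow for "prepares inputs and applies network 'fn'" is sketched below; the view-direction branch and the final reshape are assumptions based on the function signature, not lines from this commit.

    import torch

    def run_network_sketch(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
        # Flatten [..., 3] points, positionally encode them, optionally append
        # encoded viewing directions, run the MLP in netchunk-sized pieces, and
        # restore the original leading shape.
        inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
        embedded = embed_fn(inputs_flat)
        if viewdirs is not None:
            input_dirs = viewdirs[:, None].expand(inputs.shape)
            input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
            embedded = torch.cat([embedded, embeddirs_fn(input_dirs_flat)], -1)
        outputs_flat = torch.cat([fn(embedded[i:i+netchunk])
                                  for i in range(0, embedded.shape[0], netchunk)], 0)
        return torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])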
@@ -49,7 +52,8 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):


 def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
-
+    """Render rays in smaller minibatches to avoid OOM.
+    """
     all_ret = {}
     for i in range(0, rays_flat.shape[0], chunk):
         ret = render_rays(rays_flat[i:i+chunk], **kwargs)
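Note: the hunk shows the chunk loop but not how the per-chunk results are merged. Assuming render_rays returns a dict of tensors (as the render() docstring's "extras" suggests), the accumulation step likely looks like this sketch; the function name and merge logic are illustrative.

    import torch

    def batchify_rays_sketch(rays_flat, render_rays, chunk=1024*32, **kwargs):
        # Render rays in minibatches of `chunk`, collect each returned dict,
        # then concatenate the per-chunk tensors along the ray dimension.
        all_ret = {}
        for i in range(0, rays_flat.shape[0], chunk):
            ret = render_rays(rays_flat[i:i+chunk], **kwargs)
            for k in ret:
                all_ret.setdefault(k, []).append(ret[k])
        return {k: torch.cat(all_ret[k], 0) for k in all_ret}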
@@ -66,7 +70,28 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
            near=0., far=1.,
            use_viewdirs=False, c2w_staticcam=None,
            **kwargs):
-
+    """Render rays
+    Args:
+      H: int. Height of image in pixels.
+      W: int. Width of image in pixels.
+      focal: float. Focal length of pinhole camera.
+      chunk: int. Maximum number of rays to process simultaneously. Used to
+        control maximum memory usage. Does not affect final results.
+      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
+        each example in batch.
+      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
+      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
+      near: float or array of shape [batch_size]. Nearest distance for a ray.
+      far: float or array of shape [batch_size]. Farthest distance for a ray.
+      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
+      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
+        camera while using other c2w argument for viewing directions.
+    Returns:
+      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
+      disp_map: [batch_size]. Disparity map. Inverse of depth.
+      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
+      extras: dict with everything returned by render_rays().
+    """
     if c2w is not None:
         # special case to render full image
         rays_o, rays_d = get_rays(H, W, focal, c2w)
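Note: a hedged usage example for the documented interface. The `rays` tensor below only illustrates the [2, batch_size, 3] layout described in the docstring; the actual call needs the render_kwargs produced by create_nerf(), so it is left as a comment.

    import torch

    rays_o = torch.zeros(4, 3)                           # ray origins for 4 rays
    rays_d = torch.tensor([[0., 0., -1.]]).repeat(4, 1)  # unit view directions
    rays = torch.stack([rays_o, rays_d], 0)              # shape [2, 4, 3]

    # rgb, disp, acc, extras = render(H, W, focal, chunk=1024*32, rays=rays,
    #                                 near=2., far=6., **render_kwargs_test)
    # Passing c2w instead of rays triggers the "special case to render full image".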
@@ -151,6 +176,8 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N


 def create_nerf(args):
+    """Instantiate NeRF's MLP model.
+    """
     embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

     input_ch_views = 0
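Note: get_embedder(args.multires, args.i_embed) is only referenced here. As a reminder of what a frequency embedder of this kind returns, a minimal sketch is given below; it is not the committed get_embedder, just the usual sin/cos positional encoding with `multires` frequency bands.

    import torch

    def make_embedder_sketch(multires, input_dims=3):
        # Returns (embed_fn, output_channels): the raw input plus sin/cos terms
        # at frequencies 2**0 ... 2**(multires-1).
        freq_bands = 2.0 ** torch.arange(multires)
        def embed(x):
            outs = [x]
            for freq in freq_bands:
                outs.append(torch.sin(x * freq))
                outs.append(torch.cos(x * freq))
            return torch.cat(outs, -1)
        return embed, input_dims * (1 + 2 * multires)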
@@ -233,7 +260,17 @@ def create_nerf(args):


 def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
-    """ A helper function for `render_rays`.
+    """Transforms model's predictions to semantically meaningful values.
+    Args:
+        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
+        z_vals: [num_rays, num_samples along ray]. Integration time.
+        rays_d: [num_rays, 3]. Direction of each ray.
+    Returns:
+        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
+        disp_map: [num_rays]. Disparity map. Inverse of depth map.
+        acc_map: [num_rays]. Sum of weights along each ray.
+        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
+        depth_map: [num_rays]. Estimated distance to object.
     """
     raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

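Note: the raw2alpha lambda shown above is the first step of standard NeRF volume rendering. A compact sketch of how the documented outputs are typically derived from it follows; it is illustrative, omitting the noise, white-background handling, and the exact committed body.

    import torch
    import torch.nn.functional as F

    def raw2outputs_sketch(raw, z_vals, rays_d):
        # raw: [N_rays, N_samples, 4] (RGB + density), z_vals: [N_rays, N_samples].
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)      # scale by ray length

        rgb = torch.sigmoid(raw[..., :3])
        alpha = 1. - torch.exp(-F.relu(raw[..., 3]) * dists)          # per-sample opacity
        # weight_i = alpha_i * prod_{j<i} (1 - alpha_j)   (transmittance weighting)
        trans = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                         1. - alpha + 1e-10], -1), -1)[..., :-1]
        weights = alpha * trans

        rgb_map = torch.sum(weights[..., None] * rgb, -2)             # [N_rays, 3]
        depth_map = torch.sum(weights * z_vals, -1)
        acc_map = torch.sum(weights, -1)
        disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                                  depth_map / torch.clamp(acc_map, min=1e-10))
        return rgb_map, disp_map, acc_map, weights, depth_map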
@@ -281,6 +318,36 @@ def render_rays(ray_batch,
                 raw_noise_std=0.,
                 verbose=False,
                 pytest=False):
+    """Volumetric rendering.
+    Args:
+      ray_batch: array of shape [batch_size, ...]. All information necessary
+        for sampling along a ray, including: ray origin, ray direction, min
+        dist, max dist, and unit-magnitude viewing direction.
+      network_fn: function. Model for predicting RGB and density at each point
+        in space.
+      network_query_fn: function used for passing queries to network_fn.
+      N_samples: int. Number of different times to sample along each ray.
+      retraw: bool. If True, include model's raw, unprocessed predictions.
+      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
+      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
+        random points in time.
+      N_importance: int. Number of additional times to sample along each ray.
+        These samples are only passed to network_fine.
+      network_fine: "fine" network with same spec as network_fn.
+      white_bkgd: bool. If True, assume a white background.
+      raw_noise_std: ...
+      verbose: bool. If True, print more debugging info.
+    Returns:
+      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
+      disp_map: [num_rays]. Disparity map. 1 / depth.
+      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
+      raw: [num_rays, num_samples, 4]. Raw predictions from model.
+      rgb0: See rgb_map. Output for coarse model.
+      disp0: See disp_map. Output for coarse model.
+      acc0: See acc_map. Output for coarse model.
+      z_std: [num_rays]. Standard deviation of distances along ray for each
+        sample.
+    """
     N_rays = ray_batch.shape[0]
     rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
     viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
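Note: the perturb and lindisp options in the docstring refer to how sample depths are drawn along each ray. A sketch of that stratified sampling step, assuming near/far are [N_rays, 1] tensors unpacked from ray_batch (the helper name is hypothetical):

    import torch

    def sample_z_vals_sketch(near, far, N_samples, lindisp=False, perturb=1.):
        # Returns z_vals: [N_rays, N_samples] sample depths per ray.
        t_vals = torch.linspace(0., 1., steps=N_samples)
        if lindisp:
            z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
        else:
            z_vals = near * (1. - t_vals) + far * t_vals
        if perturb > 0.:
            # Stratified sampling: jitter each depth within its bin.
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)
        return z_vals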
@@ -355,74 +422,114 @@ def config_parser():

     import configargparse
     parser = configargparse.ArgumentParser()
-    parser.add_argument('--config', is_config_file=True, help='config file path')
-    parser.add_argument("--expname", type=str, help='experiment name')
-    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
-    parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
+    parser.add_argument('--config', is_config_file=True,
+                        help='config file path')
+    parser.add_argument("--expname", type=str,
+                        help='experiment name')
+    parser.add_argument("--basedir", type=str, default='./logs/',
+                        help='where to store ckpts and logs')
+    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
+                        help='input data directory')

     # training options
-    parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
-    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')
-    parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
-    parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')
-    parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
-    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
-    parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
-    parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
-    parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
-    parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
-    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
-    parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
+    parser.add_argument("--netdepth", type=int, default=8,
+                        help='layers in network')
+    parser.add_argument("--netwidth", type=int, default=256,
+                        help='channels per layer')
+    parser.add_argument("--netdepth_fine", type=int, default=8,
+                        help='layers in fine network')
+    parser.add_argument("--netwidth_fine", type=int, default=256,
+                        help='channels per layer in fine network')
+    parser.add_argument("--N_rand", type=int, default=32*32*4,
+                        help='batch size (number of random rays per gradient step)')
+    parser.add_argument("--lrate", type=float, default=5e-4,
+                        help='learning rate')
+    parser.add_argument("--lrate_decay", type=int, default=250,
+                        help='exponential learning rate decay (in 1000 steps)')
+    parser.add_argument("--chunk", type=int, default=1024*32,
+                        help='number of rays processed in parallel, decrease if running out of memory')
+    parser.add_argument("--netchunk", type=int, default=1024*64,
+                        help='number of pts sent through network in parallel, decrease if running out of memory')
+    parser.add_argument("--no_batching", action='store_true',
+                        help='only take random rays from 1 image at a time')
+    parser.add_argument("--no_reload", action='store_true',
+                        help='do not reload weights from saved ckpt')
+    parser.add_argument("--ft_path", type=str, default=None,
+                        help='specific weights npy file to reload for coarse network')

     # rendering options
-    parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
-    parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray')
-    parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
-    parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
-    parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
-    parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
-    parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
-    parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
+    parser.add_argument("--N_samples", type=int, default=64,
+                        help='number of coarse samples per ray')
+    parser.add_argument("--N_importance", type=int, default=0,
+                        help='number of additional fine samples per ray')
+    parser.add_argument("--perturb", type=float, default=1.,
+                        help='set to 0. for no jitter, 1. for jitter')
+    parser.add_argument("--use_viewdirs", action='store_true',
+                        help='use full 5D input instead of 3D')
+    parser.add_argument("--i_embed", type=int, default=0,
+                        help='set 0 for default positional encoding, -1 for none')
+    parser.add_argument("--multires", type=int, default=10,
+                        help='log2 of max freq for positional encoding (3D location)')
+    parser.add_argument("--multires_views", type=int, default=4,
+                        help='log2 of max freq for positional encoding (2D direction)')
+    parser.add_argument("--raw_noise_std", type=float, default=0.,
+                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

-    parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
-    parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
-    parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
+    parser.add_argument("--render_only", action='store_true',
+                        help='do not optimize, reload weights and render out render_poses path')
+    parser.add_argument("--render_test", action='store_true',
+                        help='render the test set instead of render_poses path')
+    parser.add_argument("--render_factor", type=int, default=0,
+                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

     # dataset options
-    parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')
-    parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
+    parser.add_argument("--dataset_type", type=str, default='llff',
+                        help='options: llff / blender / deepvoxels')
+    parser.add_argument("--testskip", type=int, default=8,
+                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

     ## deepvoxels flags
-    parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase')
+    parser.add_argument("--shape", type=str, default='greek',
+                        help='options : armchair / cube / greek / vase')

     ## blender flags
-    parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
-    parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
+    parser.add_argument("--white_bkgd", action='store_true',
+                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
+    parser.add_argument("--half_res", action='store_true',
+                        help='load blender synthetic data at 400x400 instead of 800x800')

     ## llff flags
-    parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
-    parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
-    parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
-    parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
-    parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
+    parser.add_argument("--factor", type=int, default=8,
+                        help='downsample factor for LLFF images')
+    parser.add_argument("--no_ndc", action='store_true',
+                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
+    parser.add_argument("--lindisp", action='store_true',
+                        help='sampling linearly in disparity rather than depth')
+    parser.add_argument("--spherify", action='store_true',
+                        help='set for spherical 360 scenes')
+    parser.add_argument("--llffhold", type=int, default=8,
+                        help='will take every 1/N images as LLFF test set, paper uses 8')

     # logging/saving options
-    parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
-    parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
-    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
-    parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
-    parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
+    parser.add_argument("--i_print", type=int, default=100,
+                        help='frequency of console printout and metric loggin')
+    parser.add_argument("--i_img", type=int, default=500,
+                        help='frequency of tensorboard image logging')
+    parser.add_argument("--i_weights", type=int, default=10000,
+                        help='frequency of weight ckpt saving')
+    parser.add_argument("--i_testset", type=int, default=50000,
+                        help='frequency of testset saving')
+    parser.add_argument("--i_video", type=int, default=50000,
+                        help='frequency of render_poses video saving')

     return parser


 def train():

     parser = config_parser()
     args = parser.parse_args()

     # Load data

     if args.dataset_type == 'llff':
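Note: because --config is an is_config_file argument, options may come from a text file instead of the command line. A hypothetical invocation and config file consistent with the flags above (the file name and values are illustrative only, not part of this commit):

    python run_nerf.py --config configs/fern.txt

    # configs/fern.txt (configargparse "key = value" format; values are examples)
    expname = fern_test
    basedir = ./logs
    datadir = ./data/llff/fern
    dataset_type = llff
    factor = 8
    N_rand = 1024
    N_samples = 64
    N_importance = 64
    use_viewdirs = True
    raw_noise_std = 1e0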
@@ -453,7 +560,6 @@ def train():
             far = 1.
         print('NEAR FAR', near, far)

-
     elif args.dataset_type == 'blender':
         images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
         print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
@@ -467,7 +573,6 @@ def train():
         else:
             images = images[...,:3]

-
     elif args.dataset_type == 'deepvoxels':

         images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
@@ -481,7 +586,6 @@ def train():
         near = hemi_R-1.
         far = hemi_R+1.

-
     else:
         print('Unknown dataset type', args.dataset_type, 'exiting')
         return
@@ -494,7 +598,6 @@ def train():
     if args.render_test:
         render_poses = np.array(poses[i_test])

-
     # Create log dir and copy the config file
     basedir = args.basedir
     expname = args.expname
@@ -509,7 +612,6 @@ def train():
         with open(f, 'w') as file:
             file.write(open(args.config, 'r').read())

-
     # Create nerf model
     render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
     global_step = start
@@ -631,9 +733,6 @@ def train():
             psnr0 = mse2psnr(img_loss0)

         loss.backward()
-
-        # NOTE: same as tf till here - 04/03/2020
-
         optimizer.step()

         # NOTE: IMPORTANT!
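Note: for context around the loss.backward()/optimizer.step() pair, the sketch below shows a generic optimization step plus the "exponential learning rate decay (in 1000 steps)" that --lrate_decay describes; the 0.1 decay base and the exact schedule are assumptions, not read from this diff.

    import torch

    def training_step_sketch(optimizer, loss, global_step, lrate=5e-4, lrate_decay=250):
        # One gradient step, then rescale the learning rate exponentially.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        decay_rate = 0.1                    # assumed base of the exponential decay
        decay_steps = lrate_decay * 1000
        new_lrate = lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate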