Add doc string

master
Yen-Chen Lin 4 years ago
parent f61ca730eb
commit c3ccc0bdb3

@ -25,6 +25,8 @@ DEBUG = False
def batchify(fn, chunk):
"""Constructs a version of 'fn' that applies to smaller batches.
"""
if chunk is None:
return fn
def ret(inputs):
@ -33,7 +35,8 @@ def batchify(fn, chunk):
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
"""Prepares inputs and applies network 'fn'.
"""
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = embed_fn(inputs_flat)
@ -49,7 +52,8 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
"""Render rays in smaller minibatches to avoid OOM.
"""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = render_rays(rays_flat[i:i+chunk], **kwargs)
@ -66,7 +70,28 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
"""Render rays
Args:
H: int. Height of image in pixels.
W: int. Width of image in pixels.
focal: float. Focal length of pinhole camera.
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
rays: array of shape [2, batch_size, 3]. Ray origin and direction for
each example in batch.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
ndc: bool. If True, represent ray origin, direction in NDC coordinates.
near: float or array of shape [batch_size]. Nearest distance for a ray.
far: float or array of shape [batch_size]. Farthest distance for a ray.
use_viewdirs: bool. If True, use viewing direction of a point in space in model.
c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
camera while using other c2w argument for viewing directions.
Returns:
rgb_map: [batch_size, 3]. Predicted RGB values for rays.
disp_map: [batch_size]. Disparity map. Inverse of depth.
acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
extras: dict with everything returned by render_rays().
"""
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, focal, c2w)
@ -151,6 +176,8 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N
def create_nerf(args):
"""Instantiate NeRF's MLP model.
"""
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
input_ch_views = 0
@ -233,7 +260,17 @@ def create_nerf(args):
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
""" A helper function for `render_rays`.
"""Transforms model's predictions to semantically meaningful values.
Args:
raw: [num_rays, num_samples along ray, 4]. Prediction from model.
z_vals: [num_rays, num_samples along ray]. Integration time.
rays_d: [num_rays, 3]. Direction of each ray.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
disp_map: [num_rays]. Disparity map. Inverse of depth map.
acc_map: [num_rays]. Sum of weights along each ray.
weights: [num_rays, num_samples]. Weights assigned to each sampled color.
depth_map: [num_rays]. Estimated distance to object.
"""
raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
@ -281,6 +318,36 @@ def render_rays(ray_batch,
raw_noise_std=0.,
verbose=False,
pytest=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
@ -355,74 +422,114 @@ def config_parser():
import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument("--expname", type=str, help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')
parser.add_argument('--config', is_config_file=True,
help='config file path')
parser.add_argument("--expname", type=str,
help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/',
help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern',
help='input data directory')
# training options
parser.add_argument("--netdepth", type=int, default=8, help='layers in network')
parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')
parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')
parser.add_argument("--netdepth", type=int, default=8,
help='layers in network')
parser.add_argument("--netwidth", type=int, default=256,
help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8,
help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256,
help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4,
help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4,
help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=250,
help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*32,
help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*64,
help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true',
help='only take random rays from 1 image at a time')
parser.add_argument("--no_reload", action='store_true',
help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
help='specific weights npy file to reload for coarse network')
# rendering options
parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
parser.add_argument("--N_samples", type=int, default=64,
help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0,
help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1.,
help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true',
help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0,
help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10,
help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4,
help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0.,
help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true',
help='render the test set instead of render_poses path')
parser.add_argument("--render_factor", type=int, default=0,
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
# dataset options
parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')
parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
parser.add_argument("--dataset_type", type=str, default='llff',
help='options: llff / blender / deepvoxels')
parser.add_argument("--testskip", type=int, default=8,
help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
## deepvoxels flags
parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase')
parser.add_argument("--shape", type=str, default='greek',
help='options : armchair / cube / greek / vase')
## blender flags
parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800')
parser.add_argument("--white_bkgd", action='store_true',
help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true',
help='load blender synthetic data at 400x400 instead of 800x800')
## llff flags
parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')
parser.add_argument("--factor", type=int, default=8,
help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true',
help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true',
help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true',
help='set for spherical 360 scenes')
parser.add_argument("--llffhold", type=int, default=8,
help='will take every 1/N images as LLFF test set, paper uses 8')
# logging/saving options
parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
parser.add_argument("--i_print", type=int, default=100,
help='frequency of console printout and metric loggin')
parser.add_argument("--i_img", type=int, default=500,
help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000,
help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000,
help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000,
help='frequency of render_poses video saving')
return parser
def train():
parser = config_parser()
args = parser.parse_args()
# Load data
if args.dataset_type == 'llff':
@ -453,7 +560,6 @@ def train():
far = 1.
print('NEAR FAR', near, far)
elif args.dataset_type == 'blender':
images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
@ -467,7 +573,6 @@ def train():
else:
images = images[...,:3]
elif args.dataset_type == 'deepvoxels':
images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
@ -481,7 +586,6 @@ def train():
near = hemi_R-1.
far = hemi_R+1.
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
@ -494,7 +598,6 @@ def train():
if args.render_test:
render_poses = np.array(poses[i_test])
# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
@ -509,7 +612,6 @@ def train():
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
# Create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
global_step = start
@ -631,9 +733,6 @@ def train():
psnr0 = mse2psnr(img_loss0)
loss.backward()
# NOTE: same as tf till here - 04/03/2020
optimizer.step()
# NOTE: IMPORTANT!

Loading…
Cancel
Save