diff --git a/run_nerf.py b/run_nerf.py index 3633e18..ead5044 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -25,6 +25,8 @@ DEBUG = False def batchify(fn, chunk): + """Constructs a version of 'fn' that applies to smaller batches. + """ if chunk is None: return fn def ret(inputs): @@ -33,7 +35,8 @@ def batchify(fn, chunk): def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): - + """Prepares inputs and applies network 'fn'. + """ inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) embedded = embed_fn(inputs_flat) @@ -49,7 +52,8 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): def batchify_rays(rays_flat, chunk=1024*32, **kwargs): - + """Render rays in smaller minibatches to avoid OOM. + """ all_ret = {} for i in range(0, rays_flat.shape[0], chunk): ret = render_rays(rays_flat[i:i+chunk], **kwargs) @@ -66,7 +70,28 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True, near=0., far=1., use_viewdirs=False, c2w_staticcam=None, **kwargs): - + """Render rays + Args: + H: int. Height of image in pixels. + W: int. Width of image in pixels. + focal: float. Focal length of pinhole camera. + chunk: int. Maximum number of rays to process simultaneously. Used to + control maximum memory usage. Does not affect final results. + rays: array of shape [2, batch_size, 3]. Ray origin and direction for + each example in batch. + c2w: array of shape [3, 4]. Camera-to-world transformation matrix. + ndc: bool. If True, represent ray origin, direction in NDC coordinates. + near: float or array of shape [batch_size]. Nearest distance for a ray. + far: float or array of shape [batch_size]. Farthest distance for a ray. + use_viewdirs: bool. If True, use viewing direction of a point in space in model. + c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for + camera while using other c2w argument for viewing directions. + Returns: + rgb_map: [batch_size, 3]. Predicted RGB values for rays. + disp_map: [batch_size]. Disparity map. Inverse of depth. + acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. + extras: dict with everything returned by render_rays(). + """ if c2w is not None: # special case to render full image rays_o, rays_d = get_rays(H, W, focal, c2w) @@ -151,6 +176,8 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N def create_nerf(args): + """Instantiate NeRF's MLP model. + """ embed_fn, input_ch = get_embedder(args.multires, args.i_embed) input_ch_views = 0 @@ -233,7 +260,17 @@ def create_nerf(args): def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): - """ A helper function for `render_rays`. + """Transforms model's predictions to semantically meaningful values. + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. """ raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) @@ -281,6 +318,36 @@ def render_rays(ray_batch, raw_noise_std=0., verbose=False, pytest=False): + """Volumetric rendering. + Args: + ray_batch: array of shape [batch_size, ...]. All information necessary + for sampling along a ray, including: ray origin, ray direction, min + dist, max dist, and unit-magnitude viewing direction. + network_fn: function. Model for predicting RGB and density at each point + in space. + network_query_fn: function used for passing queries to network_fn. + N_samples: int. Number of different times to sample along each ray. + retraw: bool. If True, include model's raw, unprocessed predictions. + lindisp: bool. If True, sample linearly in inverse depth rather than in depth. + perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified + random points in time. + N_importance: int. Number of additional times to sample along each ray. + These samples are only passed to network_fine. + network_fine: "fine" network with same spec as network_fn. + white_bkgd: bool. If True, assume a white background. + raw_noise_std: ... + verbose: bool. If True, print more debugging info. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. + disp_map: [num_rays]. Disparity map. 1 / depth. + acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. + raw: [num_rays, num_samples, 4]. Raw predictions from model. + rgb0: See rgb_map. Output for coarse model. + disp0: See disp_map. Output for coarse model. + acc0: See acc_map. Output for coarse model. + z_std: [num_rays]. Standard deviation of distances along ray for each + sample. + """ N_rays = ray_batch.shape[0] rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None @@ -355,74 +422,114 @@ def config_parser(): import configargparse parser = configargparse.ArgumentParser() - parser.add_argument('--config', is_config_file=True, help='config file path') - parser.add_argument("--expname", type=str, help='experiment name') - parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs') - parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory') + parser.add_argument('--config', is_config_file=True, + help='config file path') + parser.add_argument("--expname", type=str, + help='experiment name') + parser.add_argument("--basedir", type=str, default='./logs/', + help='where to store ckpts and logs') + parser.add_argument("--datadir", type=str, default='./data/llff/fern', + help='input data directory') # training options - parser.add_argument("--netdepth", type=int, default=8, help='layers in network') - parser.add_argument("--netwidth", type=int, default=256, help='channels per layer') - parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network') - parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network') - parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)') - parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') - parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)') - parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory') - parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory') - parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time') - parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') - parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network') + parser.add_argument("--netdepth", type=int, default=8, + help='layers in network') + parser.add_argument("--netwidth", type=int, default=256, + help='channels per layer') + parser.add_argument("--netdepth_fine", type=int, default=8, + help='layers in fine network') + parser.add_argument("--netwidth_fine", type=int, default=256, + help='channels per layer in fine network') + parser.add_argument("--N_rand", type=int, default=32*32*4, + help='batch size (number of random rays per gradient step)') + parser.add_argument("--lrate", type=float, default=5e-4, + help='learning rate') + parser.add_argument("--lrate_decay", type=int, default=250, + help='exponential learning rate decay (in 1000 steps)') + parser.add_argument("--chunk", type=int, default=1024*32, + help='number of rays processed in parallel, decrease if running out of memory') + parser.add_argument("--netchunk", type=int, default=1024*64, + help='number of pts sent through network in parallel, decrease if running out of memory') + parser.add_argument("--no_batching", action='store_true', + help='only take random rays from 1 image at a time') + parser.add_argument("--no_reload", action='store_true', + help='do not reload weights from saved ckpt') + parser.add_argument("--ft_path", type=str, default=None, + help='specific weights npy file to reload for coarse network') # rendering options - parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray') - parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray') - parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter') - parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D') - parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none') - parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)') - parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)') - parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') + parser.add_argument("--N_samples", type=int, default=64, + help='number of coarse samples per ray') + parser.add_argument("--N_importance", type=int, default=0, + help='number of additional fine samples per ray') + parser.add_argument("--perturb", type=float, default=1., + help='set to 0. for no jitter, 1. for jitter') + parser.add_argument("--use_viewdirs", action='store_true', + help='use full 5D input instead of 3D') + parser.add_argument("--i_embed", type=int, default=0, + help='set 0 for default positional encoding, -1 for none') + parser.add_argument("--multires", type=int, default=10, + help='log2 of max freq for positional encoding (3D location)') + parser.add_argument("--multires_views", type=int, default=4, + help='log2 of max freq for positional encoding (2D direction)') + parser.add_argument("--raw_noise_std", type=float, default=0., + help='std dev of noise added to regularize sigma_a output, 1e0 recommended') - parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path') - parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') - parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') + parser.add_argument("--render_only", action='store_true', + help='do not optimize, reload weights and render out render_poses path') + parser.add_argument("--render_test", action='store_true', + help='render the test set instead of render_poses path') + parser.add_argument("--render_factor", type=int, default=0, + help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') # dataset options - parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels') - parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') + parser.add_argument("--dataset_type", type=str, default='llff', + help='options: llff / blender / deepvoxels') + parser.add_argument("--testskip", type=int, default=8, + help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') ## deepvoxels flags - parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase') + parser.add_argument("--shape", type=str, default='greek', + help='options : armchair / cube / greek / vase') ## blender flags - parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)') - parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800') + parser.add_argument("--white_bkgd", action='store_true', + help='set to render synthetic data on a white bkgd (always use for dvoxels)') + parser.add_argument("--half_res", action='store_true', + help='load blender synthetic data at 400x400 instead of 800x800') ## llff flags - parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images') - parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') - parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth') - parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes') - parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8') + parser.add_argument("--factor", type=int, default=8, + help='downsample factor for LLFF images') + parser.add_argument("--no_ndc", action='store_true', + help='do not use normalized device coordinates (set for non-forward facing scenes)') + parser.add_argument("--lindisp", action='store_true', + help='sampling linearly in disparity rather than depth') + parser.add_argument("--spherify", action='store_true', + help='set for spherical 360 scenes') + parser.add_argument("--llffhold", type=int, default=8, + help='will take every 1/N images as LLFF test set, paper uses 8') # logging/saving options - parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin') - parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging') - parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') - parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving') - parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving') + parser.add_argument("--i_print", type=int, default=100, + help='frequency of console printout and metric loggin') + parser.add_argument("--i_img", type=int, default=500, + help='frequency of tensorboard image logging') + parser.add_argument("--i_weights", type=int, default=10000, + help='frequency of weight ckpt saving') + parser.add_argument("--i_testset", type=int, default=50000, + help='frequency of testset saving') + parser.add_argument("--i_video", type=int, default=50000, + help='frequency of render_poses video saving') return parser - def train(): parser = config_parser() args = parser.parse_args() - # Load data if args.dataset_type == 'llff': @@ -453,7 +560,6 @@ def train(): far = 1. print('NEAR FAR', near, far) - elif args.dataset_type == 'blender': images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) @@ -467,7 +573,6 @@ def train(): else: images = images[...,:3] - elif args.dataset_type == 'deepvoxels': images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, @@ -481,7 +586,6 @@ def train(): near = hemi_R-1. far = hemi_R+1. - else: print('Unknown dataset type', args.dataset_type, 'exiting') return @@ -494,7 +598,6 @@ def train(): if args.render_test: render_poses = np.array(poses[i_test]) - # Create log dir and copy the config file basedir = args.basedir expname = args.expname @@ -509,7 +612,6 @@ def train(): with open(f, 'w') as file: file.write(open(args.config, 'r').read()) - # Create nerf model render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) global_step = start @@ -631,9 +733,6 @@ def train(): psnr0 = mse2psnr(img_loss0) loss.backward() - - # NOTE: same as tf till here - 04/03/2020 - optimizer.step() # NOTE: IMPORTANT!