From c9074d8b60e0232b3af6976e076223a9e68d6e57 Mon Sep 17 00:00:00 2001 From: Yen-Chen Lin Date: Fri, 17 Apr 2020 22:34:48 -0400 Subject: [PATCH] Add crop --- run_nerf.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/run_nerf.py b/run_nerf.py index 695cd6e..a3e1f37 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -482,6 +482,12 @@ def config_parser(): parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') + # training options + parser.add_argument("--precrop_iters", type=int, default=0, + help='number of steps to train on central crops') + parser.add_argument("--precrop_frac", type=float, + default=.5, help='fraction of img taken for central crops') + # dataset options parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels') @@ -707,7 +713,20 @@ def train(): if N_rand is not None: rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) - coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + + if i < args.precrop_iters: + dH = int(H//2 * args.precrop_frac) + dW = int(W//2 * args.precrop_frac) + coords = torch.stack( + torch.meshgrid( + torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), + torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) + ), -1) + if i == start: + print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") + else: + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + coords = torch.reshape(coords, [-1,2]) # (H * W, 2) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) select_coords = coords[select_inds].long() # (N_rand, 2)