diff --git a/run_nerf.py b/run_nerf.py index 695cd6e..a3e1f37 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -482,6 +482,12 @@ def config_parser(): parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') + # training options + parser.add_argument("--precrop_iters", type=int, default=0, + help='number of steps to train on central crops') + parser.add_argument("--precrop_frac", type=float, + default=.5, help='fraction of img taken for central crops') + # dataset options parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels') @@ -707,7 +713,20 @@ def train(): if N_rand is not None: rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) - coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + + if i < args.precrop_iters: + dH = int(H//2 * args.precrop_frac) + dW = int(W//2 * args.precrop_frac) + coords = torch.stack( + torch.meshgrid( + torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), + torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) + ), -1) + if i == start: + print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") + else: + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + coords = torch.reshape(coords, [-1,2]) # (H * W, 2) select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) select_coords = coords[select_inds].long() # (N_rand, 2)