Add crop
This commit is contained in:
parent
7158181ee4
commit
c9074d8b60
1 changed files with 20 additions and 1 deletions
21
run_nerf.py
21
run_nerf.py
|
@ -482,6 +482,12 @@ def config_parser():
|
||||||
parser.add_argument("--render_factor", type=int, default=0,
|
parser.add_argument("--render_factor", type=int, default=0,
|
||||||
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
|
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
|
||||||
|
|
||||||
|
# training options
|
||||||
|
parser.add_argument("--precrop_iters", type=int, default=0,
|
||||||
|
help='number of steps to train on central crops')
|
||||||
|
parser.add_argument("--precrop_frac", type=float,
|
||||||
|
default=.5, help='fraction of img taken for central crops')
|
||||||
|
|
||||||
# dataset options
|
# dataset options
|
||||||
parser.add_argument("--dataset_type", type=str, default='llff',
|
parser.add_argument("--dataset_type", type=str, default='llff',
|
||||||
help='options: llff / blender / deepvoxels')
|
help='options: llff / blender / deepvoxels')
|
||||||
|
@ -707,7 +713,20 @@ def train():
|
||||||
|
|
||||||
if N_rand is not None:
|
if N_rand is not None:
|
||||||
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
|
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
|
||||||
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
|
|
||||||
|
if i < args.precrop_iters:
|
||||||
|
dH = int(H//2 * args.precrop_frac)
|
||||||
|
dW = int(W//2 * args.precrop_frac)
|
||||||
|
coords = torch.stack(
|
||||||
|
torch.meshgrid(
|
||||||
|
torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
|
||||||
|
torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
|
||||||
|
), -1)
|
||||||
|
if i == start:
|
||||||
|
print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
|
||||||
|
else:
|
||||||
|
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
|
||||||
|
|
||||||
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
|
coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
|
||||||
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
|
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
|
||||||
select_coords = coords[select_inds].long() # (N_rand, 2)
|
select_coords = coords[select_inds].long() # (N_rand, 2)
|
||||||
|
|
Loading…
Reference in a new issue