clean code

This commit is contained in:
Kai-46 2020-10-11 22:25:41 -04:00
parent 2af283ff67
commit 6243abf840
2 changed files with 2 additions and 112 deletions

View file

@ -224,7 +224,6 @@ def render_single_image(rank, world_size, models, ray_sampler, chunk_size):
for key in ret_merge_chunk[m]:
torch.distributed.gather(ret_merge_chunk[m][key])
# only rank 0 program returns
if rank == 0:
return ret_merge_rank
@ -488,36 +487,30 @@ def config_parser():
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument("--expname", type=str, help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
# dataset options
parser.add_argument("--datadir", type=str, default=None, help='input data directory')
parser.add_argument("--scene", type=str, default=None, help='scene name')
parser.add_argument("--testskip", type=int, default=8,
help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
# model size
parser.add_argument("--netdepth", type=int, default=8, help='layers in coarse network')
parser.add_argument("--netwidth", type=int, default=256, help='channels per layer in coarse network')
parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D')
# checkpoints
parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')
parser.add_argument("--ckpt_path", type=str, default=None,
help='specific weights npy file to reload for coarse network')
# batch size
parser.add_argument("--N_rand", type=int, default=32 * 32 * 2,
help='batch size (number of random rays per gradient step)')
parser.add_argument("--chunk_size", type=int, default=1024 * 8,
help='number of rays processed in parallel, decrease if running out of memory')
# iterations
parser.add_argument("--N_iters", type=int, default=250001,
help='number of iterations')
# render only
parser.add_argument("--render_splits", type=str, default='test',
help='splits to render')
# cascade training
parser.add_argument("--cascade_level", type=int, default=2,
help='number of cascade levels')
@ -527,26 +520,15 @@ def config_parser():
help='cuda device for each level')
parser.add_argument("--bg_devices", type=str, default='0,2',
help='cuda device for the background of each level')
# multiprocess learning
parser.add_argument("--world_size", type=int, default='-1',
help='number of processes')
# mixed precison training
parser.add_argument("--opt_level", type=str, default='O1',
help='mixed precison training')
parser.add_argument("--near_depth", type=float, default=0.1,
help='near depth plane')
parser.add_argument("--far_depth", type=float, default=50.,
help='far depth plane')
# learning rate options
parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')
parser.add_argument("--lrate_decay_factor", type=float, default=0.1,
help='decay learning rate by a factor every specified number of steps')
parser.add_argument("--lrate_decay_steps", type=int, default=5000,
help='decay learning rate by a factor every specified number of steps')
# rendering options
parser.add_argument("--inv_uniform", action='store_true',
help='if True, will uniformly sample inverse depths')
@ -555,36 +537,11 @@ def config_parser():
help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--max_freq_log2_viewdirs", type=int, default=4,
help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--N_iters_perturb", type=int, default=1000,
help='perturb and center-crop at first 1000 iterations to prevent training from getting stuck')
parser.add_argument("--raw_noise_std", type=float, default=1.,
help='std dev of noise added to regularize sigma output, 1e0 recommended')
parser.add_argument("--white_bkgd", action='store_true',
help='apply the trick to avoid fitting to white background')
# use implicit
parser.add_argument("--load_min_depth", action='store_true', help='whether to load min depth')
# no training; render only
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_train", action='store_true', help='render the training set')
parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path')
# no training; extract mesh only
parser.add_argument("--mesh_only", action='store_true',
help='do not optimize, extract mesh from pretrained model')
parser.add_argument("--N_pts", type=int, default=256,
help='voxel resolution; N_pts * N_pts * N_pts')
parser.add_argument("--mesh_thres", type=str, default='10,20,30,40,50',
help='threshold(s) for mesh extraction; can use multiple thresholds')
# logging/saving options
parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric loggin')
parser.add_argument("--i_img", type=int, default=500, help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000, help='frequency of render_poses video saving')
return parser

View file

@ -180,70 +180,3 @@ class RaySamplerSingleImage(object):
ret[k] = torch.from_numpy(ret[k])
return ret
# def random_sample_patches(self, N_patch, r_patch=16, center_crop=False):
# '''
# :param N_patch: number of patches to be sampled
# :param r_patch: patch size will be (2*r_patch+1)*(2*r_patch+1)
# :return:
# '''
# # even size patch
# # offsets to center pixels
# u, v = np.meshgrid(np.arange(-r_patch, r_patch),
# np.arange(-r_patch, r_patch))
# u = u.reshape(-1)
# v = v.reshape(-1)
# offsets = v * self.W + u
# # center pixel coordinates
# u_min = r_patch
# u_max = self.W - r_patch
# v_min = r_patch
# v_max = self.H - r_patch
# if center_crop:
# u_min = self.W // 4 + r_patch
# u_max = self.W - self.W // 4 - r_patch
# v_min = self.H // 4 + r_patch
# v_max = self.H - self.H // 4 - r_patch
# u, v = np.meshgrid(np.arange(u_min, u_max, r_patch),
# np.arange(v_min, v_max, r_patch))
# u = u.reshape(-1)
# v = v.reshape(-1)
# select_inds = np.random.choice(u.shape[0], size=(N_patch,), replace=False)
# # Convert back to original image
# select_inds = v[select_inds] * self.W + u[select_inds]
# # pick patches
# select_inds = np.stack([select_inds + shift for shift in offsets], axis=1)
# select_inds = select_inds.reshape(-1)
# rays_o = self.rays_o[select_inds, :] # [N_rand, 3]
# rays_d = self.rays_d[select_inds, :] # [N_rand, 3]
# depth = self.depth[select_inds] # [N_rand, ]
# if self.img is not None:
# rgb = self.img[select_inds, :] # [N_rand, 3]
# # ### debug
# # import imageio
# # imgs = rgb.reshape((N_patch, r_patch*2, r_patch*2, -1))
# # for kk in range(imgs.shape[0]):
# # imageio.imwrite('./debug_{}.png'.format(kk), imgs[kk])
# # ###
# else:
# rgb = None
# ret = OrderedDict([
# ('ray_o', rays_o),
# ('ray_d', rays_d),
# ('depth', depth),
# ('rgb', rgb)
# ])
# # return torch tensors
# for k in ret:
# ret[k] = torch.from_numpy(ret[k])
# return ret