From 7158181ee44bb8a888887c673feab4d57a94f2c3 Mon Sep 17 00:00:00 2001 From: Yen-Chen Lin Date: Fri, 17 Apr 2020 22:01:16 -0400 Subject: [PATCH] Better training interface --- run_nerf.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/run_nerf.py b/run_nerf.py index ead5044..695cd6e 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm +from tqdm import tqdm, trange import matplotlib.pyplot as plt @@ -673,7 +673,7 @@ def train(): rays_rgb = torch.Tensor(rays_rgb).to(device) - N_iters = 1000000 + N_iters = 200000 + 1 print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) @@ -682,7 +682,7 @@ def train(): # Summary writers # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) - for i in range(start, N_iters): + for i in trange(start, N_iters): time0 = time.time() # Sample random ray batch @@ -745,7 +745,7 @@ def train(): ################################ dt = time.time()-time0 - print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") + # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") ##### end ##### # Rest is logging @@ -784,11 +784,13 @@ def train(): print('Saved test set') - """ + if i%args.i_print==0 or i < 10: - + tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") + """ print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) print('iter time {:.05f}'.format(dt)) + with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): tf.contrib.summary.scalar('loss', loss) tf.contrib.summary.scalar('psnr', psnr)